|
155 | 155 | is_ipex_available,
156 | 156 | is_lomo_available,
157 | 157 | is_peft_available,
| 158 | + is_q_galore_torch_available, |
158 | 159 | is_safetensors_available,
159 | 160 | is_sagemaker_dp_enabled,
160 | 161 | is_sagemaker_mp_enabled,
@@ -1288,6 +1289,132 @@ def get_optimizer_cls_and_kwargs(
1288 | 1289 | optimizer_cls = torch.optim.Adagrad
1289 | 1290 | elif args.optim == OptimizerNames.RMSPROP:
1290 | 1291 | optimizer_cls = torch.optim.RMSprop
| 1292 | + elif args.optim in [OptimizerNames.QGALORE_ADAMW_8BIT, OptimizerNames.QGALORE_ADAMW_8BIT_LAYERWISE]: |
| 1293 | + if not is_q_galore_torch_available(): |
| 1294 | + raise ImportError( |
| 1295 | + "You need to install `q-galore-torch` in order to use GaLore optimizers" |
| 1296 | + " install it with `pip install qgalore" |
| 1297 | + ) |
| 1298 | + from q_galore_torch import QGaLoreAdamW8bit |
| 1299 | + |
| 1300 | + is_layerwise = args.optim.lower().endswith("layerwise") |
| 1301 | + if is_layerwise and args.parallel_mode == ParallelMode.DISTRIBUTED: |
| 1302 | + # TODO: check if this is True |
| 1303 | + raise NotImplementedError("Layer-wise QGaLore does not support DDP at this time") |
| 1304 | + |
| 1305 | + optimizer_cls = QGaLoreAdamW8bit |
| 1306 | + |
| 1307 | + if args.optim_target_modules is None: |
| 1308 | + raise ValueError( |
| 1309 | + "You need to define `optim_target_modules` in order to properly use QGaLore optimizers" |
| 1310 | + ) |
| 1315 | + |
| 1316 | + if not isinstance(args.optim_target_modules, (list, str)): |
| 1317 | + raise ValueError( |
| 1318 | + f"`optim_target_modules` has to be a list of strings, a string corresponding to a regex, or a specific module or 'all-linear', you passed {args.optim_target_modules}" |
| 1319 | + ) |
| 1320 | + |
| 1321 | + if model is None: |
| 1322 | + raise ValueError("You need to pass a model in order to correctly initialize a GaLore optimizer.") |
| 1323 | + |
| 1324 | + logger.warning( |
| 1325 | + "Activated GaLoRE fine-tuning, depending on your model size and hardware, the training might take a while before starting. Please be patient !" |
| 1326 | + ) |
| 1327 | + |
| 1328 | + all_linear = ( |
| 1329 | + isinstance(args.optim_target_modules, str) |
| 1330 | + and args.optim_target_modules.replace("_", "-") == "all-linear" |
| 1331 | + ) |
| 1332 | + |
| 1333 | + galore_params = [] |
| 1334 | + galore_params_names = [] |
| 1335 | + for module_name, module in model.named_modules(): |
| 1336 | + target_module_exists, is_regex = check_target_module_exists( |
| 1337 | + args.optim_target_modules, module_name, return_is_regex=True |
| 1338 | + ) |
| 1339 | + |
| 1340 | + if not isinstance(module, nn.Linear): |
| 1341 | + # Warn in case we match but it's not a linear layer |
| 1342 | + if target_module_exists and not is_regex: |
| 1343 | + logger.warning( |
| 1344 | + f"{module_name} has been matched but ignored as GaLore only supports linear layers. Please double check your `optim_target_modules`!" |
| 1345 | + ) |
| 1346 | + |
| 1347 | + continue |
| 1348 | + |
| 1349 | + if not target_module_exists and not all_linear: |
| 1350 | + continue |
| 1351 | + |
| 1352 | + galore_params.append(module.weight) |
| 1353 | + galore_params_names.append(module_name + ".weight") |
| 1354 | + |
| 1355 | + if len(galore_params) == 0: |
| 1356 | + raise ValueError( |
| 1357 | + f"None of the target modules were found! ({args.optim_target_modules}). Please make sure to pass a valid `target_modules`." |
| 1358 | + ) |
| 1359 | + |
| 1360 | + non_galore_params = [p for n, p in model.named_parameters() if n not in galore_params_names] |
| 1361 | + |
| 1362 | + # The default args are from the official repository: https://github.com/VITA-Group/Q-GaLore |
| 1363 | + galore_optim_kwargs = { |
| 1364 | + "rank": int(optim_args.pop("rank", 256)), |
| 1365 | + "update_proj_gap": int(optim_args.pop("update_proj_gap", 200)), |
| 1366 | + "scale": float(optim_args.pop("scale", 0.25)), |
| 1367 | + "proj_type": optim_args.pop("proj_type", "std"), |
| 1368 | + "quant": optim_args.pop("quant", True), |
| 1369 | + "quant_n_bit": optim_args.pop("quant_n_bit", 4), |
| 1370 | + "quant_group_size": optim_args.pop("quant_group_size", 256), |
| 1371 | + "cos_threshold": optim_args.pop("cos_threshold", 0.4), |
| 1372 | + "gamma_proj": optim_args.pop("gamma_proj", 2), |
| 1373 | + "queue_size": optim_args.pop("queue_size", 5), |
| 1374 | + } |
| 1375 | + |
| 1376 | + param_groups = [ |
| 1377 | + {"params": non_galore_params}, |
| 1378 | + {"params": galore_params, **galore_optim_kwargs}, |
| 1379 | + ] |
| 1380 | + |
| 1381 | + if is_layerwise: |
| 1382 | + # For layer-wise optimizers, the optimization step is done through post-accumulation |
| 1383 | + # gradient hooks. The trick is to first attach these hooks to the model parameters and then |
| 1384 | + # create a dummy optimizer that will perform no-ops in the Trainer. |
| 1385 | + # See the original implementation or the nice implementation from @hiyouga |
| 1386 | + # here: https://github.com/hiyouga/LLaMA-Factory/commit/8664262cde3919e10eaecbd66e8c5d356856362e#diff-ebe08ab14496dfb9e06075f0fdd36799ef6d1535cc4dd4715b74c4e3e06fe3ba |
| 1387 | + if args.gradient_accumulation_steps != 1: |
| 1388 | + raise ValueError("Layerwise QGaLoRE optimizer do not support gradient accumulation !") |
| 1389 | + |
| 1390 | + optimizer_dict = {} |
| 1391 | + for param in non_galore_params: |
| 1392 | + if param.requires_grad: |
| 1393 | + param_groups = [{"params": [param]}] |
| 1394 | + optimizer_dict[param] = optimizer_cls(param_groups, **optimizer_kwargs) |
| 1395 | + # TODO: in the original repo, `update_proj_gap` is multiplied by 2 for the layer-wise case; check whether we should do the same here |
| 1396 | + for param in galore_params: |
| 1397 | + param_groups = [{"params": [param], **galore_optim_kwargs}] |
| 1398 | + optimizer_dict[param] = optimizer_cls(param_groups, **optimizer_kwargs) |
| 1399 | + |
| 1400 | + def optimizer_hook(param): |
| 1401 | + if (not hasattr(param, "float_grad")) and param.grad is None: |
| 1402 | + return |
| 1403 | + optimizer_dict[param].step() |
| 1404 | + optimizer_dict[param].zero_grad() |
| 1405 | + |
| 1406 | + id_galore_params = [id(p) for p in galore_params] |
| 1407 | + |
| 1408 | + # TODO: strange, compared to GaLore we are not attaching the hook to every parameter here; to be checked |
| 1409 | + for param in model.parameters(): |
| 1410 | + if id(param) in id_galore_params or param.requires_grad: |
| 1411 | + setattr(param, "backward_hook", optimizer_hook) |
| 1412 | + |
| 1413 | + optimizer_cls = LayerWiseDummyOptimizer |
| 1414 | + optimizer_kwargs.update({"optimizer_dict": optimizer_dict}) |
| 1415 | + |
| 1416 | + optimizer_kwargs.update({"params": param_groups}) |
| 1417 | + |
1291 | 1418 | elif args.optim in [
1292 | 1419 | OptimizerNames.GALORE_ADAMW,
1293 | 1420 | OptimizerNames.GALORE_ADAMW_8BIT,
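For readers of this diff, here is a rough usage sketch showing how the new optimizer branch could be reached from `TrainingArguments`. This is not part of the PR: the concrete string values behind `OptimizerNames.QGALORE_ADAMW_8BIT(_LAYERWISE)` do not appear in the hunk above, so the `optim` strings below are assumed by analogy with the existing `galore_adamw_8bit` naming, and the target-module regexes are placeholders that depend on the actual model.

```python
# Hypothetical usage sketch (not from the PR). The "qgalore_adamw_8bit" /
# "qgalore_adamw_8bit_layerwise" strings and the regexes below are assumptions;
# the `optim_args` keys mirror the defaults shown in the diff
# (rank=256, update_proj_gap=200, scale=0.25, quant=True, quant_n_bit=4, ...).
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="qgalore-out",
    optim="qgalore_adamw_8bit",  # assumed enum value; a "_layerwise" variant for layer-wise updates
    optim_target_modules=[r".*attn.*", r".*mlp.*"],  # regexes against module names, or "all-linear"
    optim_args="rank=128,update_proj_gap=100,scale=0.25,quant_n_bit=4",  # overrides the defaults above
    gradient_accumulation_steps=1,  # the layer-wise variant raises if this is not 1
    per_device_train_batch_size=4,
)
# `training_args` is then passed to `Trainer(model=..., args=training_args, ...)` as usual;
# `get_optimizer_cls_and_kwargs` would take the QGaLore branch added in this diff.
```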
|
|
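The layer-wise branch (comment at new lines 1382-1386) can be hard to follow from the diff alone, so below is a minimal, self-contained sketch of the underlying pattern using plain `torch.optim.SGD` and PyTorch's `register_post_accumulate_grad_hook` (PyTorch >= 2.1), so it runs without `q-galore-torch`. It illustrates the general technique only: the PR instead stashes the callback on a `backward_hook` attribute (a Q-GaLore-specific detail) and hands the Trainer a `LayerWiseDummyOptimizer`, neither of which is reproduced here.

```python
# Minimal sketch of per-parameter, hook-driven optimizer steps (the idea behind
# the layer-wise variant). Assumes PyTorch >= 2.1 for register_post_accumulate_grad_hook.
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))

# One small optimizer per trainable parameter.
optimizer_dict = {
    param: torch.optim.SGD([param], lr=1e-3)
    for param in model.parameters()
    if param.requires_grad
}

def optimizer_hook(param):
    # Runs right after the gradient for `param` has been fully accumulated.
    if param.grad is None:
        return
    optimizer_dict[param].step()
    optimizer_dict[param].zero_grad()

for param in model.parameters():
    if param.requires_grad:
        param.register_post_accumulate_grad_hook(optimizer_hook)

# A single backward() now both computes gradients and updates the weights,
# which is why the Trainer only needs a no-op (dummy) optimizer afterwards.
x, y = torch.randn(8, 16), torch.randn(8, 4)
loss = nn.functional.mse_loss(model(x), y)
loss.backward()  # hooks fire here; no explicit optimizer.step() call
```

This also makes it clearer why gradient accumulation is rejected for the layer-wise variant: every `backward()` immediately applies an update, so partial gradients cannot be accumulated across micro-batches.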