Enhance Autoround to support multiple cards tuning #2157
base: main
Changes from 27 commits
The PR adds a new example script showing multi-GPU AutoRound tuning:

```python
@@ -0,0 +1,68 @@
from auto_round.calib_dataset import get_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.autoround import AutoRoundModifier
from llmcompressor.utils import dispatch_for_generation

# Select model and load it.
model_id = "Qwen/Qwen3-235B-A22B"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Select calibration dataset.
NUM_CALIBRATION_SAMPLES = 128
MAX_SEQUENCE_LENGTH = 2048
ITERS = 200

# Get aligned calibration dataset.
ds = get_dataset(
    tokenizer=tokenizer,
    seqlen=MAX_SEQUENCE_LENGTH,
    nsamples=NUM_CALIBRATION_SAMPLES,
)

# Configure the quantization algorithm to run.
# * Quantize the weights to 4 bit with AutoRound, using a group size of 128.
# * For `Qwen/Qwen3-235B-A22B`, tuning with the default settings requires about 300 GB of memory.
recipe = AutoRoundModifier(
    targets="Linear",
    scheme="W4A16",
    ignore=[
        "lm_head",
        "re:.*mlp.gate$",
    ],
    iters=ITERS,
    enable_torch_compile=False,
    device_map="0,1,2,3",  # Use 4 A100 GPUs
)

# Apply algorithms.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    shuffle_calibration_samples=False,
)

# Save to disk compressed.
SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-W4A16-G128-AutoRound"
print(f"save to {SAVE_DIR}")
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
sample = tokenizer("Hello my name is", return_tensors="pt")
sample = {key: value.to(model.device) for key, value in sample.items()}
output = model.generate(**sample, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")
```
The PR also updates the AutoRound modifier. First, the imports:

```python
@@ -1,6 +1,9 @@
from contextlib import contextmanager
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from accelerate.hooks import add_hook_to_module, remove_hook_from_submodules
from auto_round import AutoRound
from auto_round.schemes import QuantizationScheme as ARQuantizationScheme
from compressed_tensors.quantization import (
```
```python
@@ -54,6 +57,33 @@ def _wrap_decoding_layer(layer: torch.nn.Module) -> _PretrainModelWrapper:
    return wrapped_model


@contextmanager
def suspend_accelerate_hooks(model: nn.Module):
    """
    Temporarily suspend Accelerate hooks from a model.

    This context manager detaches all Accelerate hooks (used for device offloading,
    dtype casting, etc.) from the model, allowing AutoRound to operate without interference.
    On exit, the model is restored to its original device and all hooks are re-attached.
    """
    saved_hooks = {}
    original_device = next(model.parameters()).device
    for name, module in model.named_modules():
        if hasattr(module, "_hf_hook"):
            saved_hooks[name] = module._hf_hook

    remove_hook_from_submodules(model)
    try:
        yield
    finally:
        remove_hook_from_submodules(model)
        model.to(original_device)
        for name, module in model.named_modules():
            if name in saved_hooks:
                logger.info("Restoring Accelerate hook for module: {}", name)
                add_hook_to_module(module, saved_hooks[name], append=True)


class AutoRoundModifier(Modifier, QuantizationMixin):
    """
    Implements the AutoRound algorithm from https://aclanthology.org/2024.findings-emnlp.662.pdf.
```

Review thread on `suspend_accelerate_hooks`:

Collaborator: Just FYI, we are refactoring our usage of accelerate hooks for offloading. You can follow some of that in the linked work.

yiliu30 (author): Thanks! I noticed that change as well. We can adapt to it once it's ready.

Collaborator: I don't think we have anything posted on the decision to move away from accelerate hooks.
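To illustrate what the context manager does, here is a toy usage sketch. It is not from the PR: it assumes a CUDA device is available and that `suspend_accelerate_hooks` is in scope (defined as above).

```python
import torch
import torch.nn as nn
from accelerate.hooks import AlignDevicesHook, add_hook_to_module

# Toy module with an Accelerate hook attached, as dispatch/offload utilities would do.
# Attaching this hook places the layer on cuda:0.
layer = nn.Linear(16, 16)
add_hook_to_module(layer, AlignDevicesHook(execution_device=torch.device("cuda:0")))
assert hasattr(layer, "_hf_hook")

# Inside the context the hook is detached, so the layer can be moved and run
# freely (e.g. for tuning) without Accelerate re-aligning devices.
with suspend_accelerate_hooks(layer):
    layer.to("cpu")
    layer(torch.randn(2, 16))

# On exit the layer is back on its original device and the hook is re-attached.
assert hasattr(layer, "_hf_hook")
print(next(layer.parameters()).device)  # cuda:0
```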
```python
@@ -110,6 +140,10 @@ class AutoRoundModifier(Modifier, QuantizationMixin):
    iters: int = 200
    enable_torch_compile: bool = True
    batch_size: int = 8
    # optional device map for layer dispatch during tuning
    # examples: "0,1" for cuda:0,cuda:1; "auto" to use all available GPUs
    # when None, no dispatching is done and the model stays on its current device
    device_map: Optional[str] = None

    # private variables
    _all_module_input: Dict[str, List[Tuple]] = PrivateAttr(default_factory=dict)
```
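For illustration only (not part of the diff), the two string forms described in the comment above would be passed like this, reusing the recipe arguments from the example script:

```python
from llmcompressor.modifiers.autoround import AutoRoundModifier

# Pin tuning to the first two GPUs (cuda:0 and cuda:1).
recipe = AutoRoundModifier(targets="Linear", scheme="W4A16", device_map="0,1")

# Or let AutoRound dispatch layers across all available GPUs.
recipe = AutoRoundModifier(targets="Linear", scheme="W4A16", device_map="auto")
```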
```python
@@ -215,15 +249,20 @@ def apply_autoround(self, state, subgraph):
        wrapped_model = _wrap_decoding_layer(decoding_layer)
        wrapped_model.name_or_path = state.model.name_or_path

        # was: with torch.enable_grad(), align_module_device(decoding_layer):
        with torch.enable_grad(), align_module_device(
            decoding_layer
        ), suspend_accelerate_hooks(wrapped_model):
            ar_quant_scheme = self._mapping_config_to_autoround()
            fp_layers = self.get_unquantized_layer_names(decoding_layer)
            ar = AutoRound(
                model=wrapped_model,
                tokenizer="",
                scheme=ar_quant_scheme,
                iters=self.iters,
                enable_torch_compile=self.enable_torch_compile,
                batch_size=self.batch_size,
                device_map=self.device_map,
                fp_layers=",".join(fp_layers) if fp_layers else "",
            )
            # TODO: configure layer-wise config based on self.resolved_config
            ar.configure_layer_config(enable_gguf_official_mixed=False)
```
```python
@@ -232,21 +271,25 @@ def apply_autoround(self, state, subgraph):
            device = first_param.device
            cur_inputs = self._all_module_input[decoding_layer._tmp_name]
            decoding_layer.tuning_device = device
            # Leave offload for LLMC to handle if `device_map` is not set
            auto_offload = False
            if self.device_map is not None:
                # When device_map is set, we move the decoding layer to CPU first,
                # then the submodules will be re-dispatched by AutoRound.
                decoding_layer.to("cpu")
                auto_offload = True

            q_input, _ = ar.quantize_block(
                block=decoding_layer,
                inputs=cur_inputs,
                q_input=self._q_input,
                device=str(device),
                # was: auto_offload=False,  # Leave offload for LLMC
                auto_offload=auto_offload,
            )
            self._q_input = q_input
            # Update offload parameters and remove temporary attributes
            for _, module in decoding_layer.named_modules():
                # was: if hasattr(module, "weight_scale") and hasattr(module, "weight_zero_point"):
                if hasattr(module, "scale") and hasattr(module, "weight_zero_point"):
                    # Note: The model's weight is already q-dq in-place by auto-round.
                    weight_scale = module.scale
                    del module.scale
```
```python
@@ -278,6 +321,17 @@ def on_finalize(self, state: State, **kwargs) -> bool:

        return True

    def get_unquantized_layer_names(self, wrapped_model: torch.nn.Module) -> List[str]:
        unquantized_layers = []

        for name, module in wrapped_model.named_modules():
            if (
                module.__class__.__name__ in self.resolved_targets
                and getattr(module, "quantization_scheme", None) is None
            ):
                unquantized_layers.append(name)
        return unquantized_layers

    def _add_temporary_names(self, model: torch.nn.Module):
        for name, mod in model.named_modules():
            mod._tmp_name = name
```
Review thread on the `device_map` parameter name:

Reviewer: I also think the name here is confusing. Isn't `device_map` usually a different format, like a dict that maps layer names to device ids? If `device_map="0,1,2,3"` is valid in transformers, we can leave it as is; otherwise `device_ids` may be a better name.

yiliu30 (author): Yes, the `device_map` is used in Transformers, and we follow a similar approach. Please refer to: https://huggingface.co/docs/accelerate/en/usage_guides/big_modeling#accelerate

Reviewer: I've only ever seen `device_map` be a string like "auto" or "sequential", or a dictionary mapping each module name to a device id. What does it mean if `device_map="0,1,2,3"`? Is that like "auto" but only with the first 4 devices? Reference: https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference#designing-a-device-map

yiliu30 (author): Oh, you're right. Updated to `device_ids`!
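For reference, the two formats being contrasted in this thread look like the following. The dictionary is a generic Accelerate-style `device_map` sketched for illustration, with made-up module names, not something taken from this PR:

```python
# Accelerate/Transformers-style device_map: either a strategy string such as
# "auto", or an explicit mapping from module names to devices.
device_map = {
    "model.embed_tokens": 0,
    "model.layers.0": 0,
    "model.layers.1": 1,
    "lm_head": 1,
}

# The format this PR originally accepted: a comma-separated list of CUDA device
# ids, since renamed to `device_ids` per the discussion above.
device_ids = "0,1,2,3"
```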