Add option for ao base configs #36526
@@ -20,18 +20,95 @@ Install torchao with the following command.
pip install --upgrade torch torchao transformers
```

torchao supports many quantization types for different data types (int4, float8, weight only, etc.), but the Transformers integration only currently supports int8 weight quantization and int8 dynamic quantization of weights.
torchao supports many quantization types for different data types (int4, float8, weight only, etc.).
Starting with version 0.10.0, torchao provides enhanced flexibility through the `AOBaseConfig` API, allowing for more customized quantization configurations and full access to the techniques offered in the torchao library.

You can manually choose the quantization type and settings, or automatically select the quantization type.

<hfoptions id="torchao">
<hfoption id="manual">

Create a [`TorchAoConfig`] and specify the quantization type and `group_size` of the weights to quantize. Set the `cache_implementation` to `"static"` to automatically [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) the forward method.

> [!TIP]
> Run the quantized model on a CPU by changing `device_map` to `"cpu"` and `layout` to `Int4CPULayout()`. This is only available in torchao 0.8.0+.

In torchao 0.10.0+, you can use the more flexible `AOBaseConfig` approach instead of string identifiers:

```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import Int4WeightOnlyConfig

# Using AOBaseConfig instance (torchao >= 0.10.0)
quant_config = Int4WeightOnlyConfig(group_size=128)
quantization_config = TorchAoConfig(quant_type=quant_config)

# Load and quantize the model
quantized_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    torch_dtype="auto",
    device_map="auto",
    quantization_config=quantization_config
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

# auto-compile the quantized model with `cache_implementation="static"` to get speed up
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))
```

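Per the TIP above, a CPU variant of this example swaps the device map and layout. This is a minimal sketch, assuming torchao 0.10.0+ where `Int4CPULayout` is importable from `torchao.dtypes` and `Int4WeightOnlyConfig` accepts a `layout` argument:

```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import Int4WeightOnlyConfig
from torchao.dtypes import Int4CPULayout  # assumed import path for the CPU layout

# int4 weight-only quantization laid out for CPU execution
quant_config = Int4WeightOnlyConfig(group_size=128, layout=Int4CPULayout())
quantization_config = TorchAoConfig(quant_type=quant_config)

quantized_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    torch_dtype="auto",
    device_map="cpu",  # run on CPU instead of GPU
    quantization_config=quantization_config,
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
input_ids = tokenizer("What are we having for dinner?", return_tensors="pt")  # inputs stay on CPU
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
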
## Available Quantization Schemes

TorchAO provides a variety of quantization configurations:

- `Int4WeightOnlyConfig`
- `Int8WeightOnlyConfig`
- `Int8DynamicActivationInt8WeightConfig`
- `Float8WeightOnlyConfig`

Each configuration can be further customized with parameters such as `group_size`, `scheme`, and `layout` to optimize for specific hardware and model architectures, as sketched below.

For a complete list of available configurations, see our [quantization API documentation](https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_api.py).
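
For instance, a config object can be customized before being handed to [`TorchAoConfig`]. A short sketch (parameter support varies by config class and torchao version, so treat the keyword arguments as illustrative):

```py
from transformers import TorchAoConfig
from torchao.quantization import Int4WeightOnlyConfig, Float8WeightOnlyConfig

# int4 weight-only quantization with a smaller group size for finer-grained scales
int4_config = TorchAoConfig(quant_type=Int4WeightOnlyConfig(group_size=64))

# float8 weight-only quantization with its default settings
float8_config = TorchAoConfig(quant_type=Float8WeightOnlyConfig())
```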

> **⚠️ DEPRECATION WARNING**
>
> Starting with version 0.10.0, the string-based API for quantization configuration (e.g., `TorchAoConfig("int4_weight_only", group_size=128)`) is **deprecated** and will be removed in a future release.

Contributor: I remember it's fine for transformers to always depend on the most recent torchao versions.

Member: If possible, we would like to support older versions of torchao too, but I feel like for now it's fine for the user to download the most recent version of torchao.
>
> Please use the new `AOBaseConfig`-based approach instead:
>
> ```python
> # Old way (deprecated)
> quantization_config = TorchAoConfig("int4_weight_only", group_size=128)
>
> # New way (recommended)
> from torchao.quantization import Int4WeightOnlyConfig
> quant_config = Int4WeightOnlyConfig(group_size=128)
> quantization_config = TorchAoConfig(quant_type=quant_config)
> ```
>
> The new API offers greater flexibility, better type safety, and access to the full range of features available in torchao.
>
> ## Migration Guide
>
> Here's how to migrate from common string identifiers to their `AOBaseConfig` equivalents:
>
> | Old String API | New `AOBaseConfig` API |
> |----------------|------------------------|
> | `"int4_weight_only"` | `Int4WeightOnlyConfig()` |
> | `"int8_weight_only"` | `Int8WeightOnlyConfig()` |
> | `"int8_dynamic_activation_int8_weight"` | `Int8DynamicActivationInt8WeightConfig()` |
>
> All configuration objects accept parameters for customization (e.g., `group_size`, `scheme`, `layout`).

Below is the API for torchao < `0.9.0`.

```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
```

@@ -78,7 +155,7 @@ print("bf16 model:", benchmark_fn(bf16_model.generate, **input_ids, max_new_toke

The [autoquant](https://pytorch.org/ao/stable/generated/torchao.quantization.autoquant.html#torchao.quantization.autoquant) API automatically chooses a quantization type for quantizable layers (`nn.Linear`) by micro-benchmarking on input type and shape and compiling a single linear layer.

Create a [`TorchAoConfig`] and set it to `"autoquant"`. Set the `cache_implementation` to `"static"` to automatically [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) the forward method. Finally, call `finalize_autoquant` on the quantized model to finalize the quantization and log the input shapes, as sketched below.

> [!TIP]
> Run the quantized model on a CPU by changing `device_map` to `"cpu"` and `layout` to `Int4CPULayout()`. This is only available in torchao 0.8.0+.

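Putting those steps together, the autoquant flow looks roughly like this. This is a sketch rather than a definitive recipe: it assumes autoquant attaches a `finalize_autoquant` method to the quantized model, as described above.

```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer

quantization_config = TorchAoConfig("autoquant")
quantized_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    torch_dtype="auto",
    device_map="auto",
    quantization_config=quantization_config,
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
input_ids = tokenizer("What are we having for dinner?", return_tensors="pt").to("cuda")

# generating triggers the micro-benchmarks that drive the per-layer quantization choices
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")

# finalize the quantization and log the input shapes (method assumed to be
# attached by autoquant, per the description above)
quantized_model.finalize_autoquant()
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
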
@@ -131,7 +208,7 @@ print("bf16 model:", benchmark_fn(bf16_model.generate, **input_ids, max_new_toke

## Serialization

torchao implements [torch.Tensor subclasses](https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor) for maximum flexibility in supporting new quantized torch.Tensor formats. [Safetensors](https://huggingface.co/docs/safetensors/en/index) serialization and deserialization does not work with torchaco.
torchao implements [torch.Tensor subclasses](https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor) for maximum flexibility in supporting new quantized torch.Tensor formats. [Safetensors](https://huggingface.co/docs/safetensors/en/index) serialization and deserialization does not work with torchao.

To avoid arbitrary user code execution, torchao sets `weights_only=True` in [torch.load](https://pytorch.org/docs/stable/generated/torch.load.html) to ensure only tensors are loaded. Any known user functions can be whitelisted with [add_safe_globals](https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals).

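To make the save/load flow concrete, here is a minimal sketch. The output directory name is hypothetical, and `quantized_model`/`tokenizer` are assumed to come from one of the examples above:

```py
# safetensors is unsupported for torchao checkpoints, so opt out explicitly
quantized_model.save_pretrained("./llama-3-8b-torchao", safe_serialization=False)
tokenizer.save_pretrained("./llama-3-8b-torchao")

# reloading goes through torch.load(weights_only=True) under the hood
from transformers import AutoModelForCausalLM

reloaded_model = AutoModelForCausalLM.from_pretrained(
    "./llama-3-8b-torchao",
    torch_dtype="auto",
    device_map="auto",
)
```
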
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import re
import types
from typing import TYPE_CHECKING, Optional, Union

@@ -27,6 +28,7 @@
from typing import Any, Dict, List

from ..utils import is_torch_available, is_torchao_available, logging
from ..utils.quantization_config import TorchAoConfig


if is_torch_available():

@@ -36,6 +38,21 @@
logger = logging.get_logger(__name__)

def fuzzy_match_size(config_name: str) -> Optional[str]:
    """
    Extract the size digit from strings like "4weight", "8weight".
    Returns the digit as a string if found, otherwise None.
    """
    config_name = config_name.lower()

    str_match = re.search(r"(\d)weight", config_name)

    if str_match:
        return str_match.group(1)

    return None
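
# e.g. fuzzy_match_size("Int4WeightOnlyConfig") -> "4"
#      fuzzy_match_size("Float8WeightOnlyConfig") -> "8"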


# Finds the parent of a node module named "name"
def find_parent(model, name):
    module_tree = name.split(".")[:-1]

@@ -121,10 +138,28 @@ def update_torch_dtype(self, torch_dtype):
            torch_dtype = torch.float32
        return torch_dtype

    def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
    def adjust_target_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
        if version.parse(importlib.metadata.version("accelerate")) > version.parse("0.19.0"):
            from accelerate.utils import CustomDtype

            # Import AOBaseConfig directly since we know we have the right version
            if self.quantization_config._get_ao_version() >= version.Version("0.10.0"):
                from torchao.core.config import AOBaseConfig

                quant_type = self.quantization_config.quant_type
                if isinstance(quant_type, AOBaseConfig):
                    # Extract size digit using fuzzy match on the class name
                    config_name = quant_type.__class__.__name__
                    size_digit = fuzzy_match_size(config_name)

Contributor: this seems a bit fragile? e.g. what would it look like for mx, fp4 etc.

                    # Map the extracted digit to appropriate dtype
                    if size_digit == "4":
                        return CustomDtype.INT4
                    else:
                        # Default to int8
                        return torch.int8

Comment on lines +156 to +160

Contributor: I'm wondering if these are really needed, cc @SunMarc when are these used?

Member: They are used when calculating the appropriate device_map (e.g. to know how to dispatch the layers to the different GPUs). This is needed in the torchao case as the model architecture is not changed prior to calculating the device_map.

            # Original mapping for non-AOBaseConfig types
            map_to_target_dtype = {
                "int4_weight_only": CustomDtype.INT4,
                "int8_weight_only": torch.int8,

@@ -194,14 +229,14 @@ def create_quantized_param(
        from torchao.quantization import quantize_

        module, tensor_name = get_module_from_name(model, param_name)

        if self.pre_quantized:
            module._parameters[tensor_name] = torch.nn.Parameter(param_value.to(device=target_device))
            if isinstance(module, nn.Linear):
                module.extra_repr = types.MethodType(_linear_extra_repr, module)
        else:
            assert isinstance(self.quantization_config, TorchAoConfig)
            module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
            quantize_(module, self.quantization_config.get_apply_tensor_subclass())
            quantize_(module, self.quantization_config.get_quantize_config())

    def _process_model_after_weight_loading(self, model, **kwargs):
        """No process required for torchao quantized model"""

@@ -216,7 +251,7 @@ def _process_model_after_weight_loading(self, model, **kwargs):
            return model
        return

    def is_serializable(self, safe_serialization=None):
    def is_serializable(self, safe_serialization=None) -> bool:
        if safe_serialization:
            logger.warning(
                "torchao quantized model does not support safe serialization, "

@@ -237,7 +272,7 @@ def is_serializable(self, safe_serialization=None):
        return _is_torchao_serializable

    @property
    def is_trainable(self):
    def is_trainable(self) -> bool:
        supported_quant_types_for_training = [
            "int8_weight_only",
            "int8_dynamic_activation_int8_weight",