83 changes: 80 additions & 3 deletions docs/source/en/quantization/torchao.md
@@ -20,18 +20,95 @@ Install torchao with the following command.

```bash
pip install --upgrade torch torchao transformers
```

torchao supports many quantization types for different data types (int4, float8, weight only, etc.), but the Transformers integration only currently supports int8 weight quantization and int8 dynamic quantization of weights.
torchao supports many quantization types for different data types (int4, float8, weight only, etc.).
Starting with version 0.10.0, torchao provides enhanced flexibility through the `AOBaseConfig` API, allowing for more customized quantization configurations and full access to the techniques offered in the torchao library.

You can choose the quantization types and settings manually, or let torchao select them automatically.

<hfoptions id="torchao">
<hfoption id="manual">


Create a [`TorchAoConfig`] and specify the quantization type and `group_size` of the weights to quantize. Set the `cache_implementation` to `"static"` to automatically [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) the forward method.

> [!TIP]
> Run the quantized model on a CPU by changing `device_map` to `"cpu"` and `layout` to `Int4CPULayout()`. This is only available in torchao 0.8.0+.
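
A minimal sketch of that CPU setup, assuming `Int4CPULayout` is importable from `torchao.dtypes` in torchao 0.8.0+ (the model name is illustrative and mirrors the example below):

```py
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.dtypes import Int4CPULayout  # assumed import path for torchao >= 0.8.0

# int4 weight-only quantization with a CPU-friendly layout, dispatched entirely to CPU
quantization_config = TorchAoConfig("int4_weight_only", group_size=128, layout=Int4CPULayout())
quantized_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    torch_dtype="auto",
    device_map="cpu",
    quantization_config=quantization_config,
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
input_ids = tokenizer("What are we having for dinner?", return_tensors="pt")
output = quantized_model.generate(**input_ids, max_new_tokens=10)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```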

In torchao 0.10.0+, you can use the more flexible `AOBaseConfig` approach instead of string identifiers:

```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import Int4WeightOnlyConfig

# Using AOBaseConfig instance (torchao >= 0.10.0)
quant_config = Int4WeightOnlyConfig(group_size=128)
quantization_config = TorchAoConfig(quant_type=quant_config)

# Load and quantize the model
quantized_model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Meta-Llama-3-8B",
torch_dtype="auto",
device_map="auto",
quantization_config=quantization_config
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

# auto-compile the quantized model with `cache_implementation="static"` to get a speedup
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))
```

## Available Quantization Schemes

torchao provides a variety of quantization configurations:

- `Int4WeightOnlyConfig`
- `Int8WeightOnlyConfig`
- `Int8DynamicActivationInt8WeightConfig`
- `Float8WeightOnlyConfig`

Each configuration can be further customized with parameters such as `group_size`, `scheme`, and `layout` to optimize for specific hardware and model architectures.

For a complete list of available configurations, see our [quantization API documentation](https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_api.py).
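
As a rough sketch (assuming these config classes are exposed by `torchao.quantization` in 0.10.0+), each of them is passed to [`TorchAoConfig`] the same way as `Int4WeightOnlyConfig` above:

```py
from torchao.quantization import Int8WeightOnlyConfig, Int8DynamicActivationInt8WeightConfig
from transformers import TorchAoConfig

# int8 weight-only quantization
int8_weight_only_config = TorchAoConfig(quant_type=Int8WeightOnlyConfig())

# int8 dynamic activation + int8 weight quantization
int8_dynamic_config = TorchAoConfig(quant_type=Int8DynamicActivationInt8WeightConfig())
```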

> **⚠️ DEPRECATION WARNING**
>
> Starting with version 0.10.0, the string-based API for quantization configuration (e.g., `TorchAoConfig("int4_weight_only", group_size=128)`) is **deprecated** and will be removed in a future release.
**Review comment (Contributor):** I remember it's fine for transformers to always depend on the most recent torchao versions.

**Review comment (Member):** If possible, we would like to support older versions of torchao too, but I feel like for now it's fine for the user to download the most recent version of torchao.

>
> Please use the new `AOBaseConfig`-based approach instead:
>
> ```python
> # Old way (deprecated)
> quantization_config = TorchAoConfig("int4_weight_only", group_size=128)
>
> # New way (recommended)
> from torchao.quantization import Int4WeightOnlyConfig
> quant_config = Int4WeightOnlyConfig(group_size=128)
> quantization_config = TorchAoConfig(quant_type=quant_config)
> ```
>
> The new API offers greater flexibility, better type safety, and access to the full range of features available in torchao.
>
> ## Migration Guide
>
> Here's how to migrate from common string identifiers to their `AOBaseConfig` equivalents:
>
> | Old String API | New `AOBaseConfig` API |
> |----------------|------------------------|
> | `"int4_weight_only"` | `Int4WeightOnlyConfig()` |
> | `"int8_weight_only"` | `Int8WeightOnlyConfig()` |
> | `"int8_dynamic_activation_int8_weight"` | `Int8DynamicActivationInt8WeightConfig()` |
>
> All configuration objects accept parameters for customization (e.g., `group_size`, `scheme`, `layout`).


Below is the API for torchao < `0.9.0`:

```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
@@ -78,7 +155,7 @@ print("bf16 model:", benchmark_fn(bf16_model.generate, **input_ids, max_new_toke

The [autoquant](https://pytorch.org/ao/stable/generated/torchao.quantization.autoquant.html#torchao.quantization.autoquant) API automatically chooses a quantization type for quantizable layers (`nn.Linear`) by micro-benchmarking on input type and shape and compiling a single linear layer.

Create a [`TorchAoConfig`] and set it to `"autoquant"`. Set the `cache_implementation` to `"static"` to automatically [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) the forward method. Finally, call `finalize_autoquant` on the quantized model to finalize the quantization and log the input shapes.

> [!TIP]
> Run the quantized model on a CPU by changing `device_map` to `"cpu"` and `layout` to `Int4CPULayout()`. This is only available in torchao 0.8.0+.
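
A rough sketch of that flow (the full example is collapsed in the diff below; the model name and prompt are illustrative):

```py
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer

quantization_config = TorchAoConfig("autoquant")
quantized_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    torch_dtype="auto",
    device_map="auto",
    quantization_config=quantization_config,
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
input_ids = tokenizer("What are we having for dinner?", return_tensors="pt").to("cuda")

# generation micro-benchmarks the quantizable layers; `cache_implementation="static"` compiles the forward method
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
# finalize the quantization once representative input shapes have been seen
quantized_model.finalize_autoquant()
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
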
@@ -131,7 +208,7 @@ print("bf16 model:", benchmark_fn(bf16_model.generate, **input_ids, max_new_toke

## Serialization

torchao implements [torch.Tensor subclasses](https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor) for maximum flexibility in supporting new quantized torch.Tensor formats. [Safetensors](https://huggingface.co/docs/safetensors/en/index) serialization and deserialization does not work with torchaco.
torchao implements [torch.Tensor subclasses](https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor) for maximum flexibility in supporting new quantized torch.Tensor formats. [Safetensors](https://huggingface.co/docs/safetensors/en/index) serialization and deserialization does not work with torchao.

To avoid arbitrary user code execution, torchao sets `weights_only=True` in [torch.load](https://pytorch.org/docs/stable/generated/torch.load.html) to ensure only tensors are loaded. Any known user functions can be whitelisted with [add_safe_globals](https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals).
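
A minimal save/reload sketch under those constraints (the output directory name is illustrative; `safe_serialization` must be disabled because safetensors is not supported):

```py
# save with regular torch serialization since safetensors does not work with torchao
quantized_model.save_pretrained("llama-3-8b-int4wo-128", safe_serialization=False)

# reload; torchao restores the quantized tensors via torch.load with weights_only=True
reloaded_model = AutoModelForCausalLM.from_pretrained(
    "llama-3-8b-int4wo-128",
    torch_dtype="auto",
    device_map="auto",
)
```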

45 changes: 40 additions & 5 deletions src/transformers/quantizers/quantizer_torchao.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import re
import types
from typing import TYPE_CHECKING, Optional, Union

@@ -27,6 +28,7 @@
from typing import Any, Dict, List

from ..utils import is_torch_available, is_torchao_available, logging
from ..utils.quantization_config import TorchAoConfig


if is_torch_available():
@@ -36,6 +38,21 @@
logger = logging.get_logger(__name__)


def fuzzy_match_size(config_name: str) -> Optional[str]:
"""
Extract the size digit from strings like "4weight", "8weight".
Returns the digit as a string if found, otherwise None.
"""
config_name = config_name.lower()

str_match = re.search(r"(\d)weight", config_name)

if str_match:
return str_match.group(1)

return None
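
# Illustrative examples (editor's note, not part of the diff):
#   fuzzy_match_size("Int4WeightOnlyConfig")   -> "4"
#   fuzzy_match_size("Float8WeightOnlyConfig") -> "8"
#   fuzzy_match_size("SomeOtherConfig")        -> None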


# Finds the parent of a node module named "name"
def find_parent(model, name):
module_tree = name.split(".")[:-1]
@@ -121,10 +138,28 @@ def update_torch_dtype(self, torch_dtype):
torch_dtype = torch.float32
return torch_dtype

def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
def adjust_target_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
if version.parse(importlib.metadata.version("accelerate")) > version.parse("0.19.0"):
from accelerate.utils import CustomDtype

# Import AOBaseConfig directly since we know we have the right version
if self.quantization_config._get_ao_version() >= version.Version("0.10.0"):
from torchao.core.config import AOBaseConfig

quant_type = self.quantization_config.quant_type
if isinstance(quant_type, AOBaseConfig):
# Extract size digit using fuzzy match on the class name
config_name = quant_type.__class__.__name__
size_digit = fuzzy_match_size(config_name)

**Review comment (Contributor):** This seems a bit fragile? e.g. what would it look like for mx, fp4, etc.

# Map the extracted digit to appropriate dtype
if size_digit == "4":
return CustomDtype.INT4
else:
# Default to int8
return torch.int8

**Review comment on lines +156 to +160 (@jerryzh168, Contributor, Mar 19, 2025):** I'm wondering if these are really needed, cc @SunMarc, when are these used?

**Review comment (Member):** They are used when calculating the appropriate device_map (e.g. to know how to dispatch the layers to the different GPUs). This is needed in the torchao case as the model architecture is not changed prior to calculating the device_map.

# Original mapping for non-AOBaseConfig types
map_to_target_dtype = {
"int4_weight_only": CustomDtype.INT4,
"int8_weight_only": torch.int8,
@@ -194,14 +229,14 @@ def create_quantized_param(
from torchao.quantization import quantize_

module, tensor_name = get_module_from_name(model, param_name)

if self.pre_quantized:
module._parameters[tensor_name] = torch.nn.Parameter(param_value.to(device=target_device))
if isinstance(module, nn.Linear):
module.extra_repr = types.MethodType(_linear_extra_repr, module)
else:
assert isinstance(self.quantization_config, TorchAoConfig)
module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
quantize_(module, self.quantization_config.get_apply_tensor_subclass())
quantize_(module, self.quantization_config.get_quantize_config())

def _process_model_after_weight_loading(self, model, **kwargs):
"""No process required for torchao quantized model"""
@@ -216,7 +251,7 @@ def _process_model_after_weight_loading(self, model, **kwargs):
return model
return

def is_serializable(self, safe_serialization=None):
def is_serializable(self, safe_serialization=None) -> bool:
if safe_serialization:
logger.warning(
"torchao quantized model does not support safe serialization, "
@@ -237,7 +272,7 @@ def is_serializable(self, safe_serialization=None):
return _is_torchao_serializable

@property
def is_trainable(self):
def is_trainable(self) -> bool:
supported_quant_types_for_training = [
"int8_weight_only",
"int8_dynamic_activation_int8_weight",
7 changes: 4 additions & 3 deletions src/transformers/utils/import_utils.py
@@ -95,6 +95,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
XLA_FSDPV2_MIN_VERSION = "2.2.0"
HQQ_MIN_VERSION = "0.2.1"
VPTQ_MIN_VERSION = "0.0.4"
TORCHAO_MIN_VERSION = "0.4.0"


_accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
@@ -191,7 +192,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
_timm_available = _is_package_available("timm")
_tokenizers_available = _is_package_available("tokenizers")
_torchaudio_available = _is_package_available("torchaudio")
_torchao_available = _is_package_available("torchao")
_torchao_available, _torchao_version = _is_package_available("torchao", return_version=True)
_torchdistx_available = _is_package_available("torchdistx")
_torchvision_available, _torchvision_version = _is_package_available("torchvision", return_version=True)
_mlx_available = _is_package_available("mlx")
@@ -1277,8 +1278,8 @@ def is_torchaudio_available():
return _torchaudio_available


def is_torchao_available():
return _torchao_available
def is_torchao_available(min_version: str = TORCHAO_MIN_VERSION):
return _torchao_available and version.parse(_torchao_version) >= version.parse(min_version)
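
# Illustrative usage (editor's note, not part of the diff): gate version-specific imports, e.g.
#   if is_torchao_available("0.10.0"):
#       from torchao.core.config import AOBaseConfig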


def is_speech_available():