Merged

32 commits
936b206
enable torchao quantization on CPU
jiqing-feng Feb 12, 2025
4759045
fix int4
jiqing-feng Feb 12, 2025
5e51a1c
fix format
jiqing-feng Feb 12, 2025
6b3c076
enable CPU torchao tests
jiqing-feng Feb 12, 2025
2bf0ba2
fix cuda tests
jiqing-feng Feb 12, 2025
36c6534
fix cpu tests
jiqing-feng Feb 12, 2025
872c778
update tests
jiqing-feng Feb 13, 2025
76badb1
fix style
jiqing-feng Feb 13, 2025
c964c6f
fix cuda tests
jiqing-feng Feb 13, 2025
92b3ff1
Merge branch 'main' into torchao
jiqing-feng Feb 13, 2025
fcf3e9e
fix torchao available
jiqing-feng Feb 13, 2025
a871b35
fix torchao available
jiqing-feng Feb 13, 2025
65b7de3
fix torchao config cannot convert to json
jiqing-feng Feb 13, 2025
6847b7c
Merge branch 'main' into torchao
jiqing-feng Feb 14, 2025
33da778
fix docs
jiqing-feng Feb 14, 2025
8b9b6b1
Merge branch 'main' into torchao
jiqing-feng Feb 17, 2025
50d48c2
Merge branch 'main' into torchao
jiqing-feng Feb 18, 2025
f5c2c8d
Merge branch 'main' into torchao
jiqing-feng Feb 19, 2025
e1bdbd7
rm to_dict to rebase
jiqing-feng Feb 19, 2025
a880c2c
Merge branch 'main' into torchao
MekkCyber Feb 19, 2025
49015bf
limited torchao version for CPU
jiqing-feng Feb 20, 2025
81897c4
Merge branch 'main' into torchao
jiqing-feng Feb 20, 2025
135bbab
fix format
jiqing-feng Feb 20, 2025
443b1cf
Merge branch 'main' into torchao
jiqing-feng Feb 21, 2025
248e065
fix skip
jiqing-feng Feb 21, 2025
a71d8b9
fix format
jiqing-feng Feb 21, 2025
9b3053a
Merge branch 'main' into torchao
SunMarc Feb 21, 2025
e2fef70
Update src/transformers/testing_utils.py
jiqing-feng Feb 24, 2025
66b5751
Merge branch 'main' into torchao
jiqing-feng Feb 24, 2025
d356bf6
Merge branch 'main' into torchao
jiqing-feng Feb 25, 2025
9d529ca
fix cpu test
jiqing-feng Feb 25, 2025
a633f27
fix format
jiqing-feng Feb 25, 2025
2 changes: 1 addition & 1 deletion docs/source/en/quantization/overview.md
@@ -59,7 +59,7 @@ Use the table below to help you decide which quantization method to use.
| [HQQ](./hqq.md) | 🟢 | 🟢 | 🟢 | 🔴 | 🔴 | 🔴 | 🟢 | 1/8 | 🟢 | 🔴 | 🟢 | https://github.com/mobiusml/hqq/ |
| [optimum-quanto](./quanto.md) | 🟢 | 🟢 | 🟢 | 🔴 | 🟢 | 🔴 | 🟢 | 2/4/8 | 🔴 | 🔴 | 🟢 | https://github.com/huggingface/optimum-quanto |
| [FBGEMM_FP8](./fbgemm_fp8.md) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | https://github.com/pytorch/FBGEMM |
-| [torchao](./torchao.md) | 🟢 | | 🟢 | 🔴 | 🟡 <sub>5</sub> | 🔴 | | 4/8 | | 🟢🔴 | 🟢 | https://github.com/pytorch/ao |
+| [torchao](./torchao.md) | 🟢 | 🟢 | 🟢 | 🔴 | 🟡 <sub>5</sub> | 🔴 | | 4/8 | | 🟢🔴 | 🟢 | https://github.com/pytorch/ao |
| [VPTQ](./vptq.md) | 🔴 | 🔴 | 🟢 | 🟡 | 🔴 | 🔴 | 🟢 | 1/8 | 🔴 | 🟢 | 🟢 | https://github.com/microsoft/VPTQ |
| [SpQR](./spqr.md) | 🔴 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🟢 | 3 | 🔴 | 🟢 | 🟢 | https://github.com/Vahe1994/SpQR/ |
| [FINEGRAINED_FP8](./finegrained_fp8.md) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | |
8 changes: 5 additions & 3 deletions docs/source/en/quantization/torchao.md
@@ -22,9 +22,11 @@ pip install --upgrade torch torchao transformers

By default, the weights are loaded in full precision (torch.float32) regardless of the data type they are actually stored in, such as torch.float16. Set `torch_dtype="auto"` to load the weights in the data type defined in a model's `config.json` file, which automatically picks the most memory-optimal data type.
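
For example, the following minimal sketch (the model id is illustrative) loads the weights in the dtype recorded in the checkpoint's `config.json` rather than the float32 default:

```py
from transformers import AutoModelForCausalLM

# torch_dtype="auto" reads the dtype from config.json (e.g. bfloat16) instead of defaulting to float32
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B", torch_dtype="auto")
print(model.dtype)
```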


## Manually Choose Quantization Types and Settings

`torchao` provides many commonly used quantization types, covering different dtypes such as int4 and float8 and different flavors such as weight-only and dynamic quantization. Currently only `int4_weight_only`, `int8_weight_only` and `int8_dynamic_activation_int8_weight` are integrated into Hugging Face Transformers, but more can be added when needed.
+To run the following code on CPU even when a GPU is available, set `device_map="cpu"` and use `quantization_config = TorchAoConfig("int4_weight_only", group_size=128, layout=Int4CPULayout())`, where `Int4CPULayout` comes from `from torchao.dtypes import Int4CPULayout` and is only available in torchao 0.8.0 and higher (see the CPU sketch after the example below).

Users can manually specify the quantization types and settings they want to use:

@@ -40,7 +42,7 @@ quantized_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="

tokenizer = AutoTokenizer.from_pretrained(model_name)
input_text = "What are we having for dinner?"
-input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
+input_ids = tokenizer(input_text, return_tensors="pt").to(quantized_model.device)

# auto-compile the quantized model with `cache_implementation="static"` to get speedup
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
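
The CPU path described above would look like the following sketch (assuming torchao >= 0.8.0; the model id is illustrative):

```py
from torchao.dtypes import Int4CPULayout  # available in torchao >= 0.8.0
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig

model_name = "meta-llama/Meta-Llama-3-8B"
# int4 weight-only quantization with the CPU-specific tensor layout
quantization_config = TorchAoConfig("int4_weight_only", group_size=128, layout=Int4CPULayout())
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="cpu",
    quantization_config=quantization_config,
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
input_ids = tokenizer("What are we having for dinner?", return_tensors="pt").to(quantized_model.device)
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))
```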
@@ -59,7 +61,7 @@ def benchmark_fn(func: Callable, *args, **kwargs) -> float:
MAX_NEW_TOKENS = 1000
print("int4wo-128 model:", benchmark_fn(quantized_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation="static"))

-bf16_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cuda", torch_dtype=torch.bfloat16)
+bf16_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.bfloat16)
output = bf16_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static") # auto-compile
print("bf16 model:", benchmark_fn(bf16_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation="static"))

@@ -122,7 +124,7 @@ quantized_model.save_pretrained(output_dir, safe_serialization=False)

# load quantized model
ckpt_id = "llama3-8b-int4wo-128" # or huggingface hub model id
-loaded_quantized_model = AutoModelForCausalLM.from_pretrained(ckpt_id, device_map="cuda")
+loaded_quantized_model = AutoModelForCausalLM.from_pretrained(ckpt_id, device_map="auto")


# confirm the speedup
13 changes: 13 additions & 0 deletions src/transformers/testing_utils.py
@@ -45,6 +45,7 @@
import huggingface_hub.utils
import urllib3
from huggingface_hub import delete_repo
+from packaging import version

from transformers import logging as transformers_logging

@@ -963,6 +964,18 @@ def require_torchao(test_case):
    return unittest.skipUnless(is_torchao_available(), "test requires torchao")(test_case)


+def require_torchao_version_greater_or_equal(torchao_version):
+    def decorator(test_case):
+        correct_torchao_version = is_torchao_available() and version.parse(
+            version.parse(importlib.metadata.version("torchao")).base_version
+        ) >= version.parse(torchao_version)
+        return unittest.skipUnless(
+            correct_torchao_version, f"Test requires torchao version greater than or equal to {torchao_version}."
+        )(test_case)
+
+    return decorator
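
# Usage sketch (illustrative, not part of this diff): skip a test unless torchao >= 0.8.0,
# e.g. for the int4 CPU path that needs Int4CPULayout:
#
#   @require_torchao_version_greater_or_equal("0.8.0")
#   def test_int4_weight_only_cpu(self):
#       ...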


def require_torch_tensorrt_fx(test_case):
    """Decorator marking a test that requires Torch-TensorRT FX"""
    return unittest.skipUnless(is_torch_tensorrt_fx_available(), "test requires Torch-TensorRT FX")(test_case)
12 changes: 11 additions & 1 deletion src/transformers/utils/quantization_config.py
@@ -1558,7 +1558,17 @@ def _get_torchao_quant_type_to_method(self):

    def get_apply_tensor_subclass(self):
        _STR_TO_METHOD = self._get_torchao_quant_type_to_method()
-        return _STR_TO_METHOD[self.quant_type](**self.quant_type_kwargs)
+        quant_type_kwargs = self.quant_type_kwargs.copy()
+        if (
+            not torch.cuda.is_available()
+            and is_torchao_available()
+            and self.quant_type == "int4_weight_only"
+            and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0")
+        ):
+            from torchao.dtypes import Int4CPULayout
+
+            quant_type_kwargs["layout"] = Int4CPULayout()
+        return _STR_TO_METHOD[self.quant_type](**quant_type_kwargs)
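
# Illustrative effect (not part of this diff): on a CPU-only host with torchao >= 0.8.0,
#   TorchAoConfig("int4_weight_only", group_size=128).get_apply_tensor_subclass()
# now resolves to int4_weight_only(group_size=128, layout=Int4CPULayout()).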

    def __repr__(self):
        config_dict = self.to_dict()
Expand Down