Commit 9d6abf9

jiqing-feng, MekkCyber, and SunMarc authored
enable torchao quantization on CPU (#36146)
* enable torchao quantization on CPU Signed-off-by: jiqing-feng <[email protected]>
* fix int4 Signed-off-by: jiqing-feng <[email protected]>
* fix format Signed-off-by: jiqing-feng <[email protected]>
* enable CPU torchao tests Signed-off-by: jiqing-feng <[email protected]>
* fix cuda tests Signed-off-by: jiqing-feng <[email protected]>
* fix cpu tests Signed-off-by: jiqing-feng <[email protected]>
* update tests Signed-off-by: jiqing-feng <[email protected]>
* fix style Signed-off-by: jiqing-feng <[email protected]>
* fix cuda tests Signed-off-by: jiqing-feng <[email protected]>
* fix torchao available Signed-off-by: jiqing-feng <[email protected]>
* fix torchao available Signed-off-by: jiqing-feng <[email protected]>
* fix torchao config cannot convert to json
* fix docs Signed-off-by: jiqing-feng <[email protected]>
* rm to_dict to rebase Signed-off-by: jiqing-feng <[email protected]>
* limited torchao version for CPU Signed-off-by: jiqing-feng <[email protected]>
* fix format Signed-off-by: jiqing-feng <[email protected]>
* fix skip Signed-off-by: jiqing-feng <[email protected]>
* fix format Signed-off-by: jiqing-feng <[email protected]>
* Update src/transformers/testing_utils.py Co-authored-by: Marc Sun <[email protected]>
* fix cpu test Signed-off-by: jiqing-feng <[email protected]>
* fix format Signed-off-by: jiqing-feng <[email protected]>

---------

Signed-off-by: jiqing-feng <[email protected]>
Co-authored-by: Mohamed Mekkouri <[email protected]>
Co-authored-by: Marc Sun <[email protected]>
1 parent 401543a commit 9d6abf9

5 files changed: +124 -69 lines changed

docs/source/en/quantization/overview.md

Lines changed: 1 addition & 1 deletion
@@ -59,7 +59,7 @@ Use the table below to help you decide which quantization method to use.
 | [HQQ](./hqq.md) | 🟢 | 🟢 | 🟢 | 🔴 | 🔴 | 🔴 | 🟢 | 1/8 | 🟢 | 🔴 | 🟢 | https://github.com/mobiusml/hqq/ |
 | [optimum-quanto](./quanto.md) | 🟢 | 🟢 | 🟢 | 🔴 | 🟢 | 🔴 | 🟢 | 2/4/8 | 🔴 | 🔴 | 🟢 | https://github.com/huggingface/optimum-quanto |
 | [FBGEMM_FP8](./fbgemm_fp8.md) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | https://github.com/pytorch/FBGEMM |
-| [torchao](./torchao.md) | 🟢 | | 🟢 | 🔴 | 🟡 <sub>5</sub> | 🔴 | | 4/8 | | 🟢🔴 | 🟢 | https://github.com/pytorch/ao |
+| [torchao](./torchao.md) | 🟢 | 🟢 | 🟢 | 🔴 | 🟡 <sub>5</sub> | 🔴 | | 4/8 | | 🟢🔴 | 🟢 | https://github.com/pytorch/ao |
 | [VPTQ](./vptq.md) | 🔴 | 🔴 | 🟢 | 🟡 | 🔴 | 🔴 | 🟢 | 1/8 | 🔴 | 🟢 | 🟢 | https://github.com/microsoft/VPTQ |
 | [SpQR](./spqr.md) | 🔴 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🟢 | 3 | 🔴 | 🟢 | 🟢 | https://github.com/Vahe1994/SpQR/ |
 | [FINEGRAINED_FP8](./finegrained_fp8.md) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | |

docs/source/en/quantization/torchao.md

Lines changed: 5 additions & 3 deletions
@@ -22,9 +22,11 @@ pip install --upgrade torch torchao transformers

 By default, the weights are loaded in full precision (torch.float32) regardless of the actual data type the weights are stored in such as torch.float16. Set `torch_dtype="auto"` to load the weights in the data type defined in a model's `config.json` file to automatically load the most memory-optimal data type.

+
 ## Manually Choose Quantization Types and Settings

 `torchao` provides many commonly used types of quantization, including different dtypes such as int4 and float8 and different flavors such as weight-only and dynamic quantization. Currently only `int4_weight_only`, `int8_weight_only` and `int8_dynamic_activation_int8_weight` are integrated into Hugging Face Transformers, but more can be added when needed.
+To run the following code on CPU even when a GPU is available, set `device_map="cpu"` and `quantization_config = TorchAoConfig("int4_weight_only", group_size=128, layout=Int4CPULayout())`, where `Int4CPULayout` is imported with `from torchao.dtypes import Int4CPULayout` and is only available in torchao 0.8.0 and higher.

 Users can manually specify the quantization types and settings they want to use:

@@ -40,7 +42,7 @@ quantized_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="

 tokenizer = AutoTokenizer.from_pretrained(model_name)
 input_text = "What are we having for dinner?"
-input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
+input_ids = tokenizer(input_text, return_tensors="pt").to(quantized_model.device)

 # auto-compile the quantized model with `cache_implementation="static"` to get speedup
 output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
@@ -59,7 +61,7 @@ def benchmark_fn(func: Callable, *args, **kwargs) -> float:
 MAX_NEW_TOKENS = 1000
 print("int4wo-128 model:", benchmark_fn(quantized_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation="static"))

-bf16_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cuda", torch_dtype=torch.bfloat16)
+bf16_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.bfloat16)
 output = bf16_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static") # auto-compile
 print("bf16 model:", benchmark_fn(bf16_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation="static"))

@@ -122,7 +124,7 @@ quantized_model.save_pretrained(output_dir, safe_serialization=False)

 # load quantized model
 ckpt_id = "llama3-8b-int4wo-128" # or huggingface hub model id
-loaded_quantized_model = AutoModelForCausalLM.from_pretrained(ckpt_id, device_map="cuda")
+loaded_quantized_model = AutoModelForCausalLM.from_pretrained(ckpt_id, device_map="auto")


 # confirm the speedup
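
The CPU path described in the torchao.md change can be sketched roughly as follows. This is a minimal illustration, not the documented example itself: the helper names `load_int4_cpu_model` and `generate_on_model_device` are hypothetical, `model_name` is left to the caller, and the `TorchAoConfig(...)` arguments are copied from the added documentation line. It assumes torchao 0.8.0 or higher so that `Int4CPULayout` can be imported.

```python
def load_int4_cpu_model(model_name):
    """Hypothetical helper: load `model_name` on CPU with int4 weight-only quantization."""
    # Imports are local so that defining this helper does not itself require the libraries.
    from torchao.dtypes import Int4CPULayout  # per the docs change, needs torchao >= 0.8.0
    from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig

    # Configuration taken from the documentation line added in this commit.
    quantization_config = TorchAoConfig("int4_weight_only", group_size=128, layout=Int4CPULayout())
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="cpu",  # force CPU even when a GPU is present
        quantization_config=quantization_config,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer


def generate_on_model_device(model, tokenizer, prompt, max_new_tokens=10):
    """Hypothetical helper: tokenize `prompt`, move it to the model's device, and generate."""
    # Using model.device instead of a hard-coded "cuda" mirrors the change in this diff.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    output = model.generate(**inputs, max_new_tokens=max_new_tokens, cache_implementation="static")
    return tokenizer.decode(output[0], skip_special_tokens=True)
```

Moving the inputs to `model.device` keeps the same snippet usable whether the model ends up on CPU or on an accelerator.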

src/transformers/testing_utils.py

Lines changed: 13 additions & 0 deletions
@@ -45,6 +45,7 @@
 import huggingface_hub.utils
 import urllib3
 from huggingface_hub import delete_repo
+from packaging import version

 from transformers import logging as transformers_logging

@@ -963,6 +964,18 @@ def require_torchao(test_case):
     return unittest.skipUnless(is_torchao_available(), "test requires torchao")(test_case)


+def require_torchao_version_greater_or_equal(torchao_version):
+    def decorator(test_case):
+        correct_torchao_version = is_torchao_available() and version.parse(
+            version.parse(importlib.metadata.version("torchao")).base_version
+        ) >= version.parse(torchao_version)
+        return unittest.skipUnless(
+            correct_torchao_version, f"Test requires torchao with the version greater than {torchao_version}."
+        )(test_case)
+
+    return decorator
+
+
 def require_torch_tensorrt_fx(test_case):
     """Decorator marking a test that requires Torch-TensorRT FX"""
     return unittest.skipUnless(is_torch_tensorrt_fx_available(), "test requires Torch-TensorRT FX")(test_case)
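
The idea behind the new version-gated decorator can be illustrated with a standard-library-only sketch. It is not the code from this commit: `base_release`, `version_at_least`, and `require_dist_version_at_least` are hypothetical names, and a simple regular expression stands in for `packaging.version`'s `base_version`. It handles only plain numeric releases, which is enough to show why dev and local suffixes are stripped before comparing.

```python
import importlib.metadata
import re
import unittest


def base_release(version_string):
    """Return the leading numeric release as a tuple of ints, ignoring dev/rc/local suffixes."""
    match = re.match(r"\d+(?:\.\d+)*", version_string)
    if match is None:
        raise ValueError(f"no numeric release in {version_string!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def version_at_least(installed, minimum):
    """Compare two version strings by their numeric release only."""
    a, b = base_release(installed), base_release(minimum)
    width = max(len(a), len(b))
    # Pad with zeros so that "0.8" and "0.8.0" compare as equal.
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def require_dist_version_at_least(dist_name, minimum):
    """Skip the decorated test unless `dist_name` is installed at `minimum` or newer."""
    try:
        installed = importlib.metadata.version(dist_name)
    except importlib.metadata.PackageNotFoundError:
        installed = None  # a missing package is treated as "version too old"
    ok = installed is not None and version_at_least(installed, minimum)
    return unittest.skipUnless(ok, f"test requires {dist_name}>={minimum}")


print(base_release("0.8.0.dev20250101+cpu"))       # (0, 8, 0)
print(version_at_least("0.8.0.dev20250101", "0.8.0"))  # True
print(version_at_least("0.7.9", "0.8.0"))          # False
```

Comparing on the base release means a nightly build such as `0.8.0.dev...` satisfies a `0.8.0` requirement, which a strict PEP 440 comparison would not allow.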

src/transformers/utils/quantization_config.py

Lines changed: 11 additions & 1 deletion
@@ -1558,7 +1558,17 @@ def _get_torchao_quant_type_to_method(self):

     def get_apply_tensor_subclass(self):
         _STR_TO_METHOD = self._get_torchao_quant_type_to_method()
-        return _STR_TO_METHOD[self.quant_type](**self.quant_type_kwargs)
+        quant_type_kwargs = self.quant_type_kwargs.copy()
+        if (
+            not torch.cuda.is_available()
+            and is_torchao_available()
+            and self.quant_type == "int4_weight_only"
+            and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0")
+        ):
+            from torchao.dtypes import Int4CPULayout
+
+            quant_type_kwargs["layout"] = Int4CPULayout()
+        return _STR_TO_METHOD[self.quant_type](**quant_type_kwargs)

     def __repr__(self):
         config_dict = self.to_dict()
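
The copy-then-inject pattern in `get_apply_tensor_subclass` can be sketched in isolation. The function `resolve_quant_kwargs`, its parameters, and the placeholder layout string below are hypothetical stand-ins for illustration; the real method reads these values from `torch`, `torchao`, and the config object. A plausible motivation for the copy, given the "fix torchao config cannot convert to json" entry in the commit message, is to keep the layout object out of the stored kwargs, though the diff does not state this.

```python
def resolve_quant_kwargs(quant_type, user_kwargs, cuda_available, torchao_release, cpu_layout_factory):
    """Return the kwargs for the quantization method without mutating `user_kwargs`."""
    kwargs = dict(user_kwargs)  # shallow copy, so the stored config stays unchanged
    if not cuda_available and quant_type == "int4_weight_only" and torchao_release >= (0, 8, 0):
        # Mirrors the diff: the CPU layout is set whenever the conditions hold,
        # replacing any "layout" entry the caller supplied.
        kwargs["layout"] = cpu_layout_factory()
    return kwargs


stored = {"group_size": 128}
resolved = resolve_quant_kwargs(
    "int4_weight_only", stored, False, (0, 8, 0), lambda: "cpu-layout-placeholder"
)
print(resolved)  # {'group_size': 128, 'layout': 'cpu-layout-placeholder'}
print(stored)    # {'group_size': 128}
```

Because the original dictionary is left untouched, anything that later serializes the stored kwargs sees only the values the user supplied.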
