diff --git a/README.md b/README.md
index a77cd038da..736915463f 100644
--- a/README.md
+++ b/README.md
@@ -19,8 +19,8 @@ All with no intrusive code changes and minimal accuracy degradation.
 Quantizing your models is a 1 liner that should work on any model with an `nn.Linear` including your favorite HuggingFace model. You can find a more comprehensive usage instructions [here](torchao/quantization/) and a HuggingFace inference example [here](scripts/hf_eval.py)

 ```python
-from torchao.quantization.quant_api import quantize
-m = quantize(m, "int4wo")
+from torchao.quantization.quant_api import quantize, int4_weight_only
+m = quantize(m, int4_weight_only())
 ```

 Benchmarks are run on a machine with a single A100 GPU using the script in `_models/llama` which generates text in a latency-optimized way (batchsize=1)
diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py
index 76b07f2e7a..05e84d5006 100644
--- a/test/dtypes/test_affine_quantized.py
+++ b/test/dtypes/test_affine_quantized.py
@@ -2,7 +2,7 @@
     TestCase,
     run_tests,
 )
-from torchao.quantization.quant_api import int4wo
+from torchao.quantization.quant_api import int4_weight_only
 import torch
 import unittest

@@ -12,8 +12,8 @@ class TestAffineQuantized(TestCase):
     def test_tensor_core_layout_transpose(self):
         t = torch.rand(128, 256, dtype=torch.bfloat16, device="cuda")
         shape = t.shape
-        apply_int4wo_quant = int4wo(groupsize=32)
-        aqt = apply_int4wo_quant(t)
+        apply_int4_weight_only_quant = int4_weight_only(group_size=32)
+        aqt = apply_int4_weight_only_quant(t)
         aqt_shape = aqt.shape
         self.assertEqual(aqt_shape, shape)

diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py
index a1cf2c4368..b4fbcb152a 100644
--- a/test/integration/test_integration.py
+++ b/test/integration/test_integration.py
@@ -20,9 +20,9 @@
     DynamicallyPerAxisQuantizedLinear,
 )
 from torchao.quantization.quant_api import (
-    int4wo,
-    int8wo,
-    int8da_int8w,
+    int4_weight_only,
+    int8_weight_only,
+    int8_dynamic_activation_int8_weight,
     quantize,
     _replace_with_custom_fn_if_matches_filter,
 )
@@ -98,21 +98,21 @@ def _int8wo_api(mod):
     if TORCH_VERSION_AFTER_2_4:
-        quantize(mod, int8wo())
+        quantize(mod, int8_weight_only())
         unwrap_tensor_subclass(mod)
     else:
         change_linear_weights_to_int8_woqtensors(mod)

 def _int8da_int8w_api(mod):
     if TORCH_VERSION_AFTER_2_4:
-        quantize(mod, int8da_int8w())
+        quantize(mod, int8_dynamic_activation_int8_weight())
         unwrap_tensor_subclass(mod)
     else:
         change_linear_weights_to_int8_dqtensors(mod)

 def _int4wo_api(mod):
     if TORCH_VERSION_AFTER_2_4:
-        quantize(mod, int4wo())
+        quantize(mod, int4_weight_only())
         unwrap_tensor_subclass(mod)
     else:
         change_linear_weights_to_int4_woqtensors(mod)
@@ -832,7 +832,10 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
         def api(mod):
             if TORCH_VERSION_AFTER_2_4:
-                quantize(mod, int4wo(**kwargs))
+                kwargs_copy = kwargs.copy()
+                kwargs_copy["group_size"] = groupsize
+                del kwargs_copy["groupsize"]
+                quantize(mod, int4_weight_only(**kwargs_copy))
                 unwrap_tensor_subclass(mod)
             else:
                 change_linear_weights_to_int4_woqtensors(mod, **kwargs)
@@ -853,7 +856,7 @@ def test_dynamic_quant(self):
         m = nn.Sequential(nn.Linear(K, N))

         y_ref = m(x)
-        quantize(m, int8da_int8w())
+        quantize(m, int8_dynamic_activation_int8_weight())
         y_test = m(x)

         sqnr = compute_error(y_ref, y_test)
@@ -1463,7 +1466,7 @@ def test_get_model_size_aqt(self, api, test_device, test_dtype):
         api(model)
         size2 = torchao.utils.get_model_size_in_bytes(model)
         self.assertTrue(size2 < size)
-        
+
diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index be8ef5795f..b22a157568 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -36,10 +36,10 @@
     _replace_with_custom_fn_if_matches_filter,
     Quantizer,
     TwoStepQuantizer,
-    int8da_int4w,
-    int4wo,
-    int8wo,
-    int8da_int8w,
+    int8_dynamic_activation_int4_weight,
+    int4_weight_only,
+    int8_weight_only,
+    int8_dynamic_activation_int8_weight,
 )
 from torchao.utils import (
     TORCH_VERSION_AFTER_2_3,
@@ -89,7 +89,7 @@ def convert(self, model: torch.nn.Module) -> torch.nn.Module:

 class TorchCompileDynamicQuantizer(Quantizer):
     def quantize(self, model: torch.nn.Module) -> torch.nn.Module:
-        quantize(model, int8da_int8w())
+        quantize(model, int8_dynamic_activation_int8_weight())
         return model

 class ToyLinearModel(torch.nn.Module):
@@ -152,7 +152,7 @@ class TestQuantFlow(TestCase):
     def test_dynamic_quant_gpu_singleline(self):
         m = ToyLinearModel().eval()
         example_inputs = m.example_inputs()
-        m = quantize(m, int8da_int8w())
+        m = quantize(m, int8_dynamic_activation_int8_weight())
         quantized = m(*example_inputs)
         # AssertionError: Expecting input to have dtype torch.float32, but got dtype: torch.float64
         # While executing %choose_qparams_tensor_1 : [num_users=2] = call_function[target=torch.ops.quantized_decomposed.choose_qparams.tensor](args = (%arg0_3, -128, 127, 0.000244140625, torch.int8), kwargs = {})
@@ -195,7 +195,7 @@ def test_int8_wo_quant_save_load(self):
         )
         m = ToyLinearModel().eval().cpu()
         def api(model):
-            model = quantize(model, int8wo())
+            model = quantize(model, int8_weight_only())
             unwrap_tensor_subclass(model)

         api(m)
@@ -335,7 +335,7 @@ def test_8da4w_quantizer_eval(self):
         )

     @unittest.skip("skipping until we get checkpoints for gpt-fast")
-    def test_gptq_quantizer_int4wo(self):
+    def test_gptq_quantizer_int4_weight_only(self):
         from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer
         from torchao._models._eval import InputRecorder, TransformerEvalWrapper
         torchao._models.llama.model.use_index_put_for_kv_cache = True
@@ -397,7 +397,7 @@ def test_gptq_quantizer_int4wo(self):
         )

     @unittest.skip("skipping until we get checkpoints for gpt-fast")
-    def test_quantizer_int4wo(self):
+    def test_quantizer_int4_weight_only(self):
         from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
         from torchao._models._eval import TransformerEvalWrapper
         precision = torch.bfloat16
@@ -499,11 +499,11 @@ def test_eval_wrapper_llama3(self):
     # TODO: move to a separate test file
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
     def test_quantized_tensor_subclass_8da4w(self):
-        groupsize = 32
+        group_size = 32
         m = ToyLinearModel().eval()
         m_copy = copy.deepcopy(m)
         example_inputs = m.example_inputs()
-        m = quantize(m, int8da_int4w(groupsize=groupsize))
+        m = quantize(m, int8_dynamic_activation_int4_weight(group_size=group_size))

         assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
         assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
@@ -514,7 +514,7 @@ def test_quantized_tensor_subclass_8da4w(self):
         from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
         from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear

-        quantizer = Int8DynActInt4WeightQuantizer(groupsize=groupsize)
+        quantizer = Int8DynActInt4WeightQuantizer(groupsize=group_size)
         m_copy = quantizer.quantize(m_copy)
         assert isinstance(m_copy.linear1, Int8DynActInt4WeightLinear)
         assert isinstance(m_copy.linear2, Int8DynActInt4WeightLinear)
@@ -531,13 +531,13 @@ def test_quantized_tensor_subclass_int4(self):
         m_copy = copy.deepcopy(m)
         example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cuda")

-        groupsize = 32
-        m = quantize(m, int4wo(groupsize=groupsize))
+        group_size = 32
+        m = quantize(m, int4_weight_only(group_size=group_size))
         assert isinstance(m.linear1.weight, AffineQuantizedTensor)
         assert isinstance(m.linear2.weight, AffineQuantizedTensor)

         # reference
-        _ref_change_linear_weights_to_int4_woqtensors(m_copy, groupsize=groupsize)
+        _ref_change_linear_weights_to_int4_woqtensors(m_copy, groupsize=group_size)

         res = m(*example_inputs)
         ref = m_copy(*example_inputs)
@@ -552,7 +552,7 @@ def test_quantized_tensor_subclass_int8_wo(self):
         m_copy = copy.deepcopy(m)
         example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs()))

-        m = quantize(m, int8wo())
+        m = quantize(m, int8_weight_only())

         assert isinstance(m.linear1.weight, AffineQuantizedTensor)
         assert isinstance(m.linear2.weight, AffineQuantizedTensor)
@@ -575,7 +575,7 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):
         m_copy = copy.deepcopy(m)
         # setting batch_size to 20 to be compatible with the kernel
         example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda")
-        m = quantize(m, int8da_int8w())
+        m = quantize(m, int8_dynamic_activation_int8_weight())

         assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
         assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
@@ -602,21 +602,6 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):
         # make sure it compiles
         torch._export.aot_compile(m_unwrapped, example_inputs)

-    def test_register_apply_tensor_subclass(self):
-        from torchao import register_apply_tensor_subclass
-        def apply_my_dtype(weight):
-            return weight * 2
-
-        m = ToyLinearModel().eval()
-        example_inputs = m.example_inputs()
-        with self.assertRaisesRegex(ValueError, "not supported"):
-            quantize(m, "my_dtype")
-
-        register_apply_tensor_subclass("my_dtype", apply_my_dtype)
-        # make sure it runs
-        quantize(m, "my_dtype")
-        m(*example_inputs)
-
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_quantized_tensor_subclass_save_load(self):
@@ -624,7 +609,7 @@ def test_quantized_tensor_subclass_save_load(self):
         m_copy = copy.deepcopy(m)
         example_inputs = m.example_inputs(dtype=torch.bfloat16)

-        m = quantize(m, "int8_weight_only")
+        m = quantize(m, int8_weight_only())
         ref = m(*example_inputs)
         with tempfile.NamedTemporaryFile() as f:
             torch.save(m.state_dict(), f)
diff --git a/torchao/__init__.py b/torchao/__init__.py
index 5e043026ac..0d252f5668 100644
--- a/torchao/__init__.py
+++ b/torchao/__init__.py
@@ -31,7 +31,6 @@
 from torchao.quantization import (
     autoquant,
     quantize,
-    register_apply_tensor_subclass,
 )

 from . import dtypes
@@ -39,5 +38,4 @@
     "dtypes",
     "autoquant",
     "quantize",
-    "register_apply_tensor_subclass",
 ]
diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py
index c5264048c7..4d61f97ac9 100644
--- a/torchao/dtypes/__init__.py
+++ b/torchao/dtypes/__init__.py
@@ -1,7 +1,7 @@
 from .nf4tensor import NF4Tensor, to_nf4
 # from ..prototype.dtypes.uint2 import UInt2Tensor, BitnetTensor
 from .uint4 import UInt4Tensor
-from .aqt import AffineQuantizedTensor, to_affine_quantized
+from .affine_quantized_tensor import AffineQuantizedTensor, to_affine_quantized

 __all__ = [
     "NF4Tensor",
diff --git a/torchao/dtypes/aqt.py b/torchao/dtypes/affine_quantized_tensor.py
similarity index 100%
rename from torchao/dtypes/aqt.py
rename to torchao/dtypes/affine_quantized_tensor.py
diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md
index 20e9ed3c7e..a6e95d0bed 100644
--- a/torchao/quantization/README.md
+++ b/torchao/quantization/README.md
@@ -80,7 +80,7 @@ from torch._inductor.runtime.runtime_utils import do_bench_gpu
 import copy
 from torchao.quantization.quant_api import (
     quantize,
-    int4wo,
+    int4_weight_only,
 )

 class ToyLinearModel(torch.nn.Module):
@@ -104,8 +104,8 @@ example_inputs = m.example_inputs(dtype=dtype, device="cuda")
 m_bf16 = torch.compile(m_bf16, mode='max-autotune')

 # apply int4 weight only quant (compatible with tinygemm int4 weight only quant mm kernel in torchao)
-groupsize = 32
-m = quantize(m, int4wo(groupsize=groupsize))
+group_size = 32
+m = quantize(m, int4_weight_only(group_size=group_size))

 torch._inductor.config.force_fuse_int_mm_with_mul = True
 torch._inductor.config.use_mixed_mm = True
@@ -152,7 +152,7 @@ for n, m in model.named_modules():
 The model/tensor subclass should also be compatible with AOTI and torch.export, currently we can support `torch.export.export` and `torch.aot_compile` with the following workaround:
 ```
-from torchao.quantization.utils import unwrap_tensor_subclass
+from torchao.utils import unwrap_tensor_subclass

 m_unwrapped = unwrap_tensor_subclass(m)

@@ -169,11 +169,10 @@ torch._export.aot_compile(m_unwrapped, example_inputs)
 ```python
 # Fuse the int8*int8 -> int32 matmul and subsequent mul op avoiding materialization of the int32 intermediary tensor
 torch._inductor.config.force_fuse_int_mm_with_mul = True
-from torchao.quantization import quant_api

 # for torch 2.4+
-from torchao.quantization.quant_api import quantize
-quantize(model, "int8_dynamic")
+from torchao.quantization import quantize, int8_dynamic_activation_int8_weight
+quantize(model, int8_dynamic_activation_int8_weight())

 # for torch 2.2.2 and 2.3
 from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
@@ -184,9 +183,8 @@ change_linear_weights_to_int8_dqtensors(model)
 ```python
 # for torch 2.4+
-from torchao.quantization.quant_api import quantize
-from torchao.quantization.quant_api import int8wo
-quantize(model, "int8_weight_only")
+from torchao.quantization import quantize, int8_weight_only
+quantize(model, int8_weight_only())

 # for torch 2.2.2 and 2.3
 from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
@@ -200,8 +198,8 @@ This technique works best when the torch._inductor.config.use_mixed_mm option is
 ```python
 # for torch 2.4+
-from torchao.quantization.quant_api import quantize
-quantize(model, "int4_weight_only")
+from torchao.quantization import quantize, int4_weight_only
+quantize(model, int4_weight_only())

 # for torch 2.2.2 and 2.3
 from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py
index e461171d9e..f1bb82921e 100644
--- a/torchao/quantization/__init__.py
+++ b/torchao/quantization/__init__.py
@@ -32,5 +32,8 @@
     "dequantize_affine",
     "choose_qprams_affine",
     "quantize",
-    "register_apply_tensor_subclass",
+    "int8_dynamic_activation_int4_weight",
+    "int8_dynamic_activation_int8_weight",
+    "int4_weight_only",
+    "int8_weight_only",
 ]
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 2a65d3c831..3a1516d9b5 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -54,11 +54,10 @@
     "autoquant",
     "_get_subclass_inserter",
     "quantize",
-    "int8da_int4w",
-    "int8da_int8w",
-    "int4wo",
-    "int8wo",
-    "register_apply_tensor_subclass",
+    "int8_dynamic_activation_int4_weight",
+    "int8_dynamic_activation_int8_weight",
+    "int4_weight_only",
+    "int8_weight_only",
 ]

 from .GPTQ import (
@@ -259,13 +258,12 @@ def insert_subclass(lin):

     return insert_subclass

-def quantize(model: torch.nn.Module, apply_tensor_subclass: Union[str, Callable[[torch.Tensor], torch.Tensor]], filter_fn=None) -> torch.nn.Module:
+def quantize(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor], filter_fn=None) -> torch.nn.Module:
     """Convert the weight of linear modules in the model with `apply_tensor_subclass`

     Args:
         model: input model
         apply_tensor_subclass (Callable[[torch.Tensor], torch.Tensor]): function that convert a floating point Tensor to a (quantized) tensor subclass instance
-            or a string
         filter_fn: used to filter out the modules that we don't want to apply tenosr subclass

     Example::
@@ -307,22 +305,22 @@ def filter_fn(module, fqn):
     )
     return model

-def int8da_int4w(groupsize=32):
+def int8_dynamic_activation_int4_weight(group_size=32):
     """Applies int8 dynamic per token asymmetric activation quantization and int4 per group weight symmetric quantization to linear
     This is used to produce a model for executorch backend, but currently executorch did not
     support lowering for the quantized model from this flow yet

     Args:
-        `groupsize`: parameter for quantization, controls the granularity of quantization, smaller
+        `group_size`: parameter for quantization, controls the granularity of quantization, smaller
         size is more fine grained
     """
-    def apply_8da4w_quant(weight):
+    def apply_int8_dynamic_activation_int4_weight_quant(weight):
         # avoid circular dep
         from torchao.dtypes import to_affine_quantized

         # weight settings
         mapping_type = MappingType.SYMMETRIC
-        block_size = (1, groupsize)
+        block_size = (1, group_size)
         target_dtype = torch.int8
         eps = torch.finfo(torch.float32).eps
         quant_min = -8
@@ -346,25 +344,25 @@ def get_per_token_block_size(x):
         weight = to_linear_act_quantized(weight, input_quant_func)
         return weight

-    return apply_8da4w_quant
+    return apply_int8_dynamic_activation_int4_weight_quant


-def int4wo(groupsize=128, inner_k_tiles=8):
+def int4_weight_only(group_size=128, inner_k_tiles=8):
     """
     Applies uint4 weight-only asymmetric per-group quantization to linear layers, using
     "tensor_core_tiled" layout for speedup with tinygemm kernel

     Args:
-        `groupsize`: parameter for quantization, controls the granularity of quantization, smaller
+        `group_size`: parameter for quantization, controls the granularity of quantization, smaller
         size is more fine grained, choices are [256, 128, 64, 32]
         `inner_k_tiles`: parameter for int4 mm kernel, choices are [8, 4, 2]
     """
-    def apply_int4wo_quant(weight):
+    def apply_int4_weight_only_quant(weight):
         # avoid circular dep
         from torchao.dtypes import to_affine_quantized

         mapping_type = MappingType.ASYMMETRIC
-        block_size = (1, groupsize)
+        block_size = (1, group_size)
         target_dtype = torch.int32
         quant_min = 0
         quant_max = 15
@@ -374,10 +372,10 @@ def apply_int4wo_quant(weight):
         zero_point_domain = ZeroPointDomain.FLOAT
         return to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, extended_layout="tensor_core_tiled", inner_k_tiles=inner_k_tiles)

-    return apply_int4wo_quant
+    return apply_int4_weight_only_quant


-def int8wo():
+def int8_weight_only():
     """
     Applies int8 weight-only symmetric per-channel quantization to linear layers.
     """
@@ -391,14 +389,15 @@ def apply_int8wo_quant(weight):
         zero_point_dtype = torch.int64
         block_size = (1, weight.shape[1])
         return to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
+    return apply_int8wo_quant


-def int8da_int8w():
+def int8_dynamic_activation_int8_weight():
     """
     Applies int8 dynamic symmetric per-token activation and int8 per-channel weight quantization to linear layers
     """
-    def apply_int8dyn_quant(weight):
+    def apply_int8_dynamic_activation_int8_weight_quant(weight):
         in_features = weight.shape[1]
         # int8 dynamic quantization only has benefit when in_feature > 16
         if in_features <= 16:
@@ -432,28 +431,5 @@ def get_per_token_block_size(x):
         weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
         weight = to_linear_act_quantized(weight, input_quant_func)
         return weight
-    return apply_int8dyn_quant
-
-# shortcut string to apply_tensor_subclass with a specific setting
-# to simplify common use cases
-_APPLY_TS_TABLE: Dict[str, Callable] = {
-    "int4_weight_only": int4wo(),
-    "int8_weight_only": int8wo(),
-    "int8_dynamic": int8da_int8w(),
-}
-
-def register_apply_tensor_subclass(name: str, apply_tensor_subclass: Callable):
-    """Register a string shortcut for `apply_tensor_subclass` that takes a weight Tensor
-    as input and ouptuts a tensor with tensor subclass applied
-
-    Example:
-    def apply_my_dtype(weight):
-        return weight * 2
-
-    register_apply_tensor_subclass("my_dtype", apply_my_dtype)
-    # calls `apply_my_dtype` on weights
-    quantize(m, "my_dtype")
-    """
-    if name in _APPLY_TS_TABLE:
-        logging.warning(f"shortcut string {name} already exist, overwriting")
-    _APPLY_TS_TABLE[name] = apply_tensor_subclass
+
+    return apply_int8_dynamic_activation_int8_weight_quant
diff --git a/tutorials/quantize_vit/run_vit_b_quant.py b/tutorials/quantize_vit/run_vit_b_quant.py
index 6115e5f21d..07e0118d20 100644
--- a/tutorials/quantize_vit/run_vit_b_quant.py
+++ b/tutorials/quantize_vit/run_vit_b_quant.py
@@ -20,7 +20,8 @@
 # for torch 2.4+
 from torchao.quantization.quant_api import quantize
-quantize(model, "int8_dynamic")
+from torchao.quantization.quant_api import int8_dynamic_activation_int8_weight
+quantize(model, int8_dynamic_activation_int8_weight())

 ## Quantization code - end

 ## compilation configs
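
After this patch the string shortcuts (`"int8_dynamic"`, `"int8_weight_only"`, `"int4_weight_only"`) and `register_apply_tensor_subclass` are gone, and `quantize` only accepts a callable built by the named constructors. A minimal usage sketch of the renamed torch 2.4+ flow is below; the toy `nn.Sequential` model, the bfloat16/CUDA setup, and the chosen `group_size` are illustrative assumptions, not part of the patch:

```python
import torch
from torchao.quantization import quantize, int4_weight_only

# hypothetical toy model; any module containing nn.Linear is handled the same way
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16).to("cuda")

# int4 weight-only quantization; group_size controls quantization granularity
# (smaller is more fine grained, choices per the docstring: 256, 128, 64, 32)
model = quantize(model, int4_weight_only(group_size=32))

# other configs renamed in this patch:
#   quantize(model, int8_weight_only())
#   quantize(model, int8_dynamic_activation_int8_weight())
#   quantize(model, int8_dynamic_activation_int4_weight(group_size=32))
```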