diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index 72ffc23ab6..e1e670d5da 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -18,15 +18,11 @@ from torchao.quantization.prototype.qat.api import ( ComposableQATQuantizer, ) -from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import ( - AffineFakeQuantizedTensor, -) from torchao.quantization.prototype.qat.utils import ( _choose_qparams_per_token_asymmetric, _fake_quantize_per_channel_group, _fake_quantize_per_token, _GenericFakeQuantize, - _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK, ) from torchao.quantization.quant_api import ( int4_weight_only, @@ -164,7 +160,7 @@ def _set_ptq_weight( Int8DynActInt4WeightLinear, WeightOnlyInt4Linear, ) - from torchao.quantization.prototype.qat._module_swap_api import ( + from torchao.quantization.prototype.qat.linear import ( Int8DynActInt4WeightQATLinear, Int4WeightOnlyQATLinear, ) @@ -196,7 +192,7 @@ def _set_ptq_weight( @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") def test_qat_8da4w_linear(self): - from torchao.quantization.prototype.qat._module_swap_api import Int8DynActInt4WeightQATLinear + from torchao.quantization.prototype.qat.linear import Int8DynActInt4WeightQATLinear from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear group_size = 128 @@ -219,45 +215,17 @@ def test_qat_8da4w_linear(self): ptq_out = ptq_linear(x2) torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0) - # TODO: compare against quantize_ API instead @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") def test_qat_8da4w_quantizer(self): from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer - from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer - - group_size = 16 - torch.manual_seed(self.SEED) - m = M() - m2 = copy.deepcopy(m) - qat_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size) - ptq_quantizer = Int8DynActInt4WeightQuantizer(groupsize=group_size) - qat_model = qat_quantizer.prepare(m) - ptq_model = ptq_quantizer.quantize(m2) - - # Compare model values - torch.manual_seed(self.SEED) - x = m.example_inputs() - x2 = copy.deepcopy(x) - qat_out = qat_model(*x) - ptq_out = ptq_model(*x2) - torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0) - - # Convert QAT model and compare model values - converted_model = qat_quantizer.convert(qat_model) - converted_out = converted_model(*x) - torch.testing.assert_close(ptq_out, converted_out, atol=0, rtol=0) - - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") - def test_qat_8da4w_quantizer_module_swap(self): - from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer - from torchao.quantization.prototype.qat._module_swap_api import Int8DynActInt4WeightQATQuantizerModuleSwap + from torchao.quantization.prototype.qat.linear import Int8DynActInt4WeightQATQuantizer group_size = 16 torch.manual_seed(self.SEED) m = M() m2 = copy.deepcopy(m) subclass_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size) - module_swap_quantizer = Int8DynActInt4WeightQATQuantizerModuleSwap(groupsize=group_size) + module_swap_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size) subclass_model = subclass_quantizer.prepare(m) module_swap_model = module_swap_quantizer.prepare(m2) @@ -288,20 +256,6 @@ def test_qat_8da4w_quantizer_meta_weights(self): qat_model = qat_quantizer.prepare(m) self.assertTrue(all(v.is_meta for v in qat_model.state_dict().values())) - def _copy_subclass_weights( - self, - nn_linear: torch.nn.Linear, - subclass_linear: AffineFakeQuantizedTensor, - ): - nn_linear.weight = torch.nn.Parameter(subclass_linear.weight.original_tensor) - - def _assert_matches_subclass_weights( - self, - nn_linear: torch.nn.Linear, - subclass_linear: AffineFakeQuantizedTensor, - ): - torch.testing.assert_close(nn_linear.weight, subclass_linear.weight.original_tensor, atol=0, rtol=0) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") def test_qat_8da4w_quantizer_disable_fake_quant(self): """ @@ -313,16 +267,6 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self): enable_8da4w_fake_quant, ) - def assert_fake_quant_enabled(m: torch.nn.Linear, enabled: bool): - self.assertTrue(isinstance(m.weight, AffineFakeQuantizedTensor)) - self.assertEqual(m.weight.fake_quant_enabled, enabled) - self.assertTrue(hasattr(m, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK)) - (_, handle) = getattr(m, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK) - if enabled: - self.assertIsNotNone(handle) - else: - self.assertIsNone(handle) - group_size = 16 torch.manual_seed(self.SEED) m = M() @@ -331,14 +275,14 @@ def assert_fake_quant_enabled(m: torch.nn.Linear, enabled: bool): quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size) qat_model = quantizer.prepare(m) qat_model.apply(disable_8da4w_fake_quant) - assert_fake_quant_enabled(qat_model.linear1, enabled=False) - assert_fake_quant_enabled(qat_model.linear2, enabled=False) - assert_fake_quant_enabled(qat_model.sub.linear, enabled=False) + self.assertFalse(qat_model.linear1._fake_quant_enabled) + self.assertFalse(qat_model.linear2._fake_quant_enabled) + self.assertFalse(qat_model.sub.linear._fake_quant_enabled) # Disabled fake quant is just a normal linear - self._copy_subclass_weights(m2.linear1, qat_model.linear1) - self._copy_subclass_weights(m2.linear2, qat_model.linear2) - self._copy_subclass_weights(m2.sub.linear, qat_model.sub.linear) + m2.linear1.weight = torch.nn.Parameter(qat_model.linear1.weight) + m2.linear2.weight = torch.nn.Parameter(qat_model.linear2.weight) + m2.sub.linear.weight = torch.nn.Parameter(qat_model.sub.linear.weight) torch.manual_seed(self.SEED) x = m.example_inputs() x2 = copy.deepcopy(x) @@ -348,16 +292,16 @@ def assert_fake_quant_enabled(m: torch.nn.Linear, enabled: bool): # Renable fake quant qat_model.apply(enable_8da4w_fake_quant) - assert_fake_quant_enabled(qat_model.linear1, enabled=True) - assert_fake_quant_enabled(qat_model.linear2, enabled=True) - assert_fake_quant_enabled(qat_model.sub.linear, enabled=True) + self.assertTrue(qat_model.linear1._fake_quant_enabled) + self.assertTrue(qat_model.linear2._fake_quant_enabled) + self.assertTrue(qat_model.sub.linear._fake_quant_enabled) # Fake quant should be applied as normal quantizer2 = Int8DynActInt4WeightQATQuantizer(groupsize=group_size) qat_model2 = quantizer2.prepare(m3) - qat_model2.linear1.weight.original_tensor = qat_model.linear1.weight.original_tensor - qat_model2.linear2.weight.original_tensor = qat_model.linear2.weight.original_tensor - qat_model2.sub.linear.weight.original_tensor = qat_model.sub.linear.weight.original_tensor + qat_model2.linear1.weight = qat_model.linear1.weight + qat_model2.linear2.weight = qat_model.linear2.weight + qat_model2.sub.linear.weight = qat_model.sub.linear.weight torch.manual_seed(self.SEED) x = m.example_inputs() x2 = copy.deepcopy(x) @@ -382,9 +326,9 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self): quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size) qat_model = quantizer.prepare(m) qat_model.apply(disable_8da4w_fake_quant) - self._copy_subclass_weights(nn_model.linear1, qat_model.linear1) - self._copy_subclass_weights(nn_model.linear2, qat_model.linear2) - self._copy_subclass_weights(nn_model.sub.linear, qat_model.sub.linear) + nn_model.linear1.weight = torch.nn.Parameter(qat_model.linear1.weight) + nn_model.linear2.weight = torch.nn.Parameter(qat_model.linear2.weight) + nn_model.sub.linear.weight = torch.nn.Parameter(qat_model.sub.linear.weight) # Simulate training for both models optimizer1 = torch.optim.SGD(nn_model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5) @@ -406,9 +350,9 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self): optimizer2.step() # After 1 training step, weights should match exactly - self._assert_matches_subclass_weights(nn_model.linear1, qat_model.linear1) - self._assert_matches_subclass_weights(nn_model.linear2, qat_model.linear2) - self._assert_matches_subclass_weights(nn_model.sub.linear, qat_model.sub.linear) + torch.testing.assert_close(nn_model.linear1.weight, qat_model.linear1.weight, atol=0, rtol=0) + torch.testing.assert_close(nn_model.linear2.weight, qat_model.linear2.weight, atol=0, rtol=0) + torch.testing.assert_close(nn_model.sub.linear.weight, qat_model.sub.linear.weight, atol=0, rtol=0) def _test_qat_quantized_gradients(self, quantizer): """ @@ -542,7 +486,7 @@ def test_qat_4w_primitives(self): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") def test_qat_4w_linear(self): - from torchao.quantization.prototype.qat._module_swap_api import Int4WeightOnlyQATLinear + from torchao.quantization.prototype.qat.linear import Int4WeightOnlyQATLinear from torchao.quantization.GPTQ import WeightOnlyInt4Linear group_size = 128 @@ -567,39 +511,6 @@ def test_qat_4w_linear(self): ptq_out = ptq_linear(x2) self._assert_close_4w(qat_out, ptq_out) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") - @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") - def test_qat_4w_quantizer(self): - from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer - from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer - - group_size = 32 - inner_k_tiles = 8 - device = torch.device("cuda") - dtype = torch.bfloat16 - torch.manual_seed(self.SEED) - m = M().to(device).to(dtype) - m2 = copy.deepcopy(m) - qat_quantizer = Int4WeightOnlyQATQuantizer( - groupsize=group_size, inner_k_tiles=inner_k_tiles, - ) - qat_model = qat_quantizer.prepare(m) - ptq_model = m2 - quantize_(ptq_model, int4_weight_only(group_size, TensorCoreTiledLayoutType(inner_k_tiles))) - - # Compare model values - torch.manual_seed(self.SEED) - x = [i.to(device).to(dtype) for i in m.example_inputs()] - x2 = copy.deepcopy(x) - qat_out = qat_model(*x) - ptq_out = ptq_model(*x2) - self._assert_close_4w(qat_out, ptq_out) - - # Convert QAT model and compare model values - converted_model = qat_quantizer.convert(qat_model) - converted_out = converted_model(*x) - torch.testing.assert_close(converted_out, ptq_out, atol=0, rtol=0) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") def test_qat_4w_quantizer_gradients(self): from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer @@ -608,9 +519,9 @@ def test_qat_4w_quantizer_gradients(self): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") - def test_qat_4w_quantizer_module_swap(self): + def test_qat_4w_quantizer(self): from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer - from torchao.quantization.prototype.qat._module_swap_api import Int4WeightOnlyQATQuantizerModuleSwap + from torchao.quantization.prototype.qat.linear import Int4WeightOnlyQATQuantizer group_size = 32 inner_k_tiles = 8 @@ -622,7 +533,7 @@ def test_qat_4w_quantizer_module_swap(self): subclass_quantizer = Int4WeightOnlyQATQuantizer( groupsize=group_size, inner_k_tiles=inner_k_tiles, ) - module_swap_quantizer = Int4WeightOnlyQATQuantizerModuleSwap( + module_swap_quantizer = Int4WeightOnlyQATQuantizer( groupsize=group_size, inner_k_tiles=inner_k_tiles, ) subclass_model = subclass_quantizer.prepare(m) diff --git a/torchao/quantization/prototype/qat/__init__.py b/torchao/quantization/prototype/qat/__init__.py index 62740839b7..09ea6e708d 100644 --- a/torchao/quantization/prototype/qat/__init__.py +++ b/torchao/quantization/prototype/qat/__init__.py @@ -1,17 +1,14 @@ from .api import ( + ComposableQATQuantizer, +) +from .linear import ( disable_4w_fake_quant, disable_8da4w_fake_quant, enable_4w_fake_quant, enable_8da4w_fake_quant, - int4_weight_only_fake_quantize, - int8_dynamic_activation_int4_weight_fake_quantize, - ComposableQATQuantizer, Int4WeightOnlyQATQuantizer, - Int8DynActInt4WeightQATQuantizer, -) - -from ._module_swap_api import ( Int8DynActInt4WeightQATLinear, + Int8DynActInt4WeightQATQuantizer, ) from .embedding import ( Int4WeightOnlyEmbeddingQATQuantizer, @@ -22,8 +19,6 @@ "disable_8da4w_fake_quant", "enable_4w_fake_quant", "enable_8da4w_fake_quant", - "int4_weight_only_fake_quantize", - "int8_dynamic_activation_int4_weight_fake_quantize", "ComposableQATQuantizer", "Int4WeightOnlyQATQuantizer", "Int4WeightOnlyEmbeddingQATQuantizer" diff --git a/torchao/quantization/prototype/qat/_module_swap_api.py b/torchao/quantization/prototype/qat/_module_swap_api.py index a9239a03d5..0b44974f21 100644 --- a/torchao/quantization/prototype/qat/_module_swap_api.py +++ b/torchao/quantization/prototype/qat/_module_swap_api.py @@ -1,355 +1,11 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -from typing import Any - -import torch -import torch.nn.functional as F - -from torchao.quantization.GPTQ import ( - _check_linear_int4_k, - _replace_linear_int4, - _replace_linear_8da4w, - get_groupwise_affine_qparams, - groupwise_affine_quantize_tensor, - Int8DynActInt4WeightLinear, - WeightOnlyInt4Linear, -) -from torchao.quantization.quant_primitives import ZeroPointDomain -from torchao.quantization.utils import get_group_qparams_symmetric -from .api import ( - Int8DynActInt4WeightQATQuantizer, - Int4WeightOnlyQATQuantizer, -) -from .utils import ( - _choose_qparams_per_token_asymmetric, - _fake_quantize_per_channel_group, - _fake_quantize_per_token, - _get_qmin_qmax, +# For backward compatibility only +# These will be removed in the future + +from .linear import ( + Int8DynActInt4WeightQATQuantizer as Int8DynActInt4WeightQATQuantizerModuleSwap, + Int4WeightOnlyQATQuantizer as Int4WeightOnlyQATQuantizerModuleSwap, + enable_8da4w_fake_quant as enable_8da4w_fake_quant_module_swap, + disable_8da4w_fake_quant as disable_8da4w_fake_quant_module_swap, + enable_4w_fake_quant as enable_4w_fake_quant_module_swap, + disable_4w_fake_quant as disable_4w_fake_quant_module_swap, ) - - -# TODO: make module swap the main flow again, and remove the quantize_ flow -# TODO: rename this file to linear.py - -# ========================================================= -# | Linear int8 dynamic activations + int4 weight QAT | -# ========================================================= - - -class Int8DynActInt4WeightQATQuantizerModuleSwap(Int8DynActInt4WeightQATQuantizer): - """ - Quantizer for performing QAT on a model, where linear layers have int8 - dynamic per token fake quantized activations and int4 fake quantized - grouped per channel weights. - """ - - def prepare( - self, - model: torch.nn.Module, - *args: Any, - **kwargs: Any - ) -> torch.nn.Module: - _replace_linear_8da4w( - model, - self.groupsize, - self.padding_allowed, - self.precision, - self.scales_precision, - Int8DynActInt4WeightQATLinear, - copy_weights=True, - ) - return model - - def convert( - self, - model: torch.nn.Module, - *args: Any, - **kwargs: Any - ) -> torch.nn.Module: - _convert_qat_linear_8da4w(model) - return model - - -def _convert_qat_linear_8da4w(module: torch.nn.Module): - """ - Replace all `Int8DynActInt4WeightQATLinear` with `Int8DynActInt4WeightLinear`. - """ - for name, child in module.named_children(): - if isinstance(child, Int8DynActInt4WeightQATLinear): - quantized_linear = Int8DynActInt4WeightLinear( - child.in_features, - child.out_features, - bias=False, - groupsize=child.groupsize, - precision=child.precision, - scales_precision=child.scales_precision, - ) - setattr(module, name, quantized_linear) - - # Load weights and qparams into quantized linear - n_bit = 4 - (qmin, qmax) = _get_qmin_qmax(n_bit) - (s, zp) = get_group_qparams_symmetric(child.weight, n_bit, child.groupsize) - from torchao._executorch_ops import _quantized_decomposed_quantize_per_channel_group_wrapper - q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper( - child.weight, s, zp, qmin, qmax, torch.int8, child.groupsize, - ) - quantized_linear.weight = q_weight - quantized_linear.scales = s - quantized_linear.zeros = zp - else: - _convert_qat_linear_8da4w(child) - - -class Int8DynActInt4WeightQATLinear(torch.nn.Linear): - """ - This module implements a linear layer with int8 dynamic per token fake - quantized activations with int4 fake quantized grouped per channel weights. - - args: - groupsize: the number of elements in each quantized group for weights - precision: precision of weights - scales_precision: precision of per group scales and zero points - """ - - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = False, - device: torch.device = None, - groupsize: int = 256, - precision: torch.dtype = torch.float32, - scales_precision: torch.dtype = torch.float32, - ) -> None: - super().__init__( - in_features, - out_features, - bias, - device=device, - dtype=precision, - ) - assert ( - in_features % groupsize == 0 - ), f"require in_features:{in_features} % groupsize:{groupsize} == 0" - assert not bias, "require bias=False" - self.groupsize = groupsize - self.precision = precision - self.scales_precision = scales_precision - # TODO: make this configurable? - self.zero_points_precision = torch.int32 - self._fake_quant_enabled = True - - def enable_fake_quant(self, enabled: bool = True): - self._fake_quant_enabled = enabled - - def disable_fake_quant(self): - self.enable_fake_quant(False) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - # activations: int8 dynamic asymmetric quant - if self._fake_quant_enabled: - (act_scales, act_zp) = _choose_qparams_per_token_asymmetric( - x, self.scales_precision, self.zero_points_precision, - ) - (act_qmin, act_qmax) = _get_qmin_qmax(8) - x_fq = _fake_quantize_per_token( - x, act_scales, act_zp, act_qmin, act_qmax, - ) - else: - x_fq = x - - # weights: int4 grouped per channel symmetric quant - if self._fake_quant_enabled: - (weight_scales, weight_zp) = get_group_qparams_symmetric( - self.weight, 4, self.groupsize, self.scales_precision, - ) - # TODO: pass zp dtype to `get_group_qparams_symmetric` instead - weight_zp = weight_zp.to(self.zero_points_precision) - (weight_qmin, weight_qmax) = _get_qmin_qmax(4) - w_fq = _fake_quantize_per_channel_group( - self.weight, - weight_scales, - weight_zp, - weight_qmin, - weight_qmax, - self.groupsize, - ) - else: - w_fq = self.weight - return F.linear(x_fq, w_fq) - - -def enable_8da4w_fake_quant_module_swap(mod: torch.nn.Module): - """ - Enable fake quantization for `Int8DynActInt4WeightQATLinear`. - """ - if isinstance(mod, Int8DynActInt4WeightQATLinear): - mod.enable_fake_quant() - - -def disable_8da4w_fake_quant_module_swap(mod: torch.nn.Module): - """ - Disable fake quantization for `Int8DynActInt4WeightQATLinear`. - """ - if isinstance(mod, Int8DynActInt4WeightQATLinear): - mod.disable_fake_quant() - - -# =================================== -# | Linear int4 weight-only QAT | -# =================================== - - -class Int4WeightOnlyQATQuantizerModuleSwap(Int4WeightOnlyQATQuantizer): - """ - Quantizer for performing QAT on a model, where linear layers have - int4 fake quantized grouped per channel weights. - """ - - def prepare( - self, - model: torch.nn.Module, - *args: Any, - **kwargs: Any - ) -> torch.nn.Module: - _replace_linear_int4( - model, - self.groupsize, - self.inner_k_tiles, - padding_allowed=True, - precision=self.precision, - scales_precision=self.scales_precision, - linear_class=Int4WeightOnlyQATLinear, - copy_weights=True, - ) - return model - - def convert( - self, - model: torch.nn.Module, - *args: Any, - **kwargs: Any - ) -> torch.nn.Module: - _convert_qat_linear_4w(model) - return model - - -def _convert_qat_linear_4w(module: torch.nn.Module): - """ - Replace all `Int4WeightOnlyQATLinear` with `WeightOnlyInt4Linear`. - """ - for name, child in module.named_children(): - if isinstance(child, Int4WeightOnlyQATLinear): - in_features = child.in_features - out_features = child.out_features - groupsize = child.groupsize - inner_k_tiles = child.inner_k_tiles - quantized_linear = WeightOnlyInt4Linear( - in_features, - out_features, - bias=False, - groupsize=groupsize, - inner_k_tiles=inner_k_tiles, - precision=child.precision, - scales_precision=child.scales_precision, - ) - setattr(module, name, quantized_linear) - - # Load weights and qparams into quantized linear - n_bit = 4 - (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor( - child.weight, n_bit, child.groupsize, - ) - q_weight = torch.ops.aten._convert_weight_to_int4pack( - q_weight.to(child.weight.device), child.inner_k_tiles, - ) - quantized_linear.weight = q_weight - quantized_linear.scales_and_zeros = scales_and_zeros - else: - _convert_qat_linear_4w(child) - - -class Int4WeightOnlyQATLinear(torch.nn.Linear): - """ - This module implements a linear layer with int4 fake quantized grouped - per channel weights, with forward numerics matching `WeightOnlyInt4Linear`, - which uses the efficient int4 tinygemm kernel. - - args: - groupsize: the number of elements in each quantized group for weights - precision: precision of weights - scales_precision: precision of per group scales and zero points - """ - - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = False, - device: torch.device = None, - groupsize: int = 256, - inner_k_tiles: int = 8, - precision: torch.dtype = torch.bfloat16, - scales_precision: torch.dtype = torch.bfloat16, - ) -> None: - super().__init__( - in_features, - out_features, - bias, - device=device, - dtype=precision, - ) - assert not bias, "require bias=False" - assert scales_precision == torch.bfloat16, "only bf16 is supported for scales" - if not _check_linear_int4_k(in_features, groupsize, inner_k_tiles): - raise ValueError("Padding for QAT 4w is not supported yet") - self.groupsize = groupsize - self.inner_k_tiles = inner_k_tiles - self.precision = precision - self.scales_precision = scales_precision - self._fake_quant_enabled = True - - def enable_fake_quant(self, enabled: bool = True): - self._fake_quant_enabled = enabled - - def disable_fake_quant(self): - self.enable_fake_quant(False) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - n_bit = 4 - qmin = 0 - qmax = 2 ** n_bit - 1 - scales, zero_points = get_groupwise_affine_qparams( - self.weight, n_bit, self.groupsize, self.scales_precision, - ) - w_fq = _fake_quantize_per_channel_group( - self.weight, - scales, - zero_points, - qmin, - qmax, - self.groupsize, - ZeroPointDomain.FLOAT, - ) - return F.linear(x, w_fq) - - -def enable_4w_fake_quant_module_swap(mod: torch.nn.Module): - """ - Enable fake quantization for `Int4WeightOnlyQATLinear`. - """ - if isinstance(mod, Int4WeightOnlyQATLinear): - mod.enable_fake_quant() - - -def disable_4w_fake_quant_module_swap(mod: torch.nn.Module): - """ - Disable fake quantization for `Int4WeightOnlyQATLinear`. - """ - if isinstance(mod, Int4WeightOnlyQATLinear): - mod.disable_fake_quant() diff --git a/torchao/quantization/prototype/qat/api.py b/torchao/quantization/prototype/qat/api.py index e1c5221e1e..93717271bb 100644 --- a/torchao/quantization/prototype/qat/api.py +++ b/torchao/quantization/prototype/qat/api.py @@ -4,34 +4,11 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, List, Optional +from typing import Any, List import torch -import torch.nn.functional as F -from torchao.dtypes import ( - TensorCoreTiledLayoutType, -) -from torchao.quantization.quant_api import ( - _get_linear_subclass_inserter, - _replace_with_custom_fn_if_matches_filter, - int4_weight_only, - int8_dynamic_activation_int4_weight, - quantize_, -) -from torchao.quantization.quant_primitives import ( - MappingType, - ZeroPointDomain, -) from torchao.quantization.unified import TwoStepQuantizer -from torchao.quantization.utils import _get_per_token_block_size -from .affine_fake_quantized_tensor import to_affine_fake_quantized -from .utils import ( - _enable_fake_quant, - _get_qat_linear_subclass_inserter, - _is_linear_with_fq_weight, - _unwrap_affine_fake_quantized_tensor, -) class ComposableQATQuantizer(TwoStepQuantizer): @@ -70,207 +47,3 @@ def convert( for quantizer in self.quantizers: model = quantizer.convert(model) return model - - -# ================= -# | 8da4w QAT | -# ================= - -def int8_dynamic_activation_int4_weight_fake_quantize(group_size=32): - """ - Applies int8 dynamic per token asymmetric activation fake quantization and - int4 per group weight symmetric fake quantization to linear. Please see - :func:`~torchao.quantization.int8_dynamic_activation_int4_weight` for more details. - - Example usage:: - - from torchao.quantization import quantize_ - quantize_(model, int8_dynamic_activation_int4_weight_fake_quantize(group_size=32)) - """ - # avoid circular dep - from torchao.dtypes import to_affine_quantized_intx - - def _apply_weight_fake_quant(weight: torch.Tensor): - mapping_type = MappingType.SYMMETRIC - block_size = (1, group_size) - target_dtype = torch.int8 - eps = torch.finfo(torch.float32).eps - quant_min = -8 - quant_max = 7 - return to_affine_fake_quantized( - weight, - mapping_type, - block_size, - target_dtype, - quant_min, - quant_max, - eps, - ) - - def _apply_input_activation_fake_quant(x: torch.Tensor): - mapping_type = MappingType.ASYMMETRIC - target_dtype = torch.int8 - return to_affine_fake_quantized( - x, - mapping_type, - _get_per_token_block_size(x), - target_dtype, - ) - - return _get_qat_linear_subclass_inserter( - _apply_weight_fake_quant, - _apply_input_activation_fake_quant, - ) - -class Int8DynActInt4WeightQATQuantizer(TwoStepQuantizer): - """ - Quantizer for performing QAT on a model, where linear layers have int8 - dynamic per token fake quantized activations and int4 fake quantized - grouped per channel weights. - """ - - def __init__( - self, - groupsize: int = 256, - padding_allowed: bool = False, - precision: torch.dtype = torch.float32, - scales_precision: torch.dtype = torch.float32, - ) -> None: - super().__init__() - self.groupsize: int = groupsize - self.padding_allowed: bool = padding_allowed - self.precision: torch.dtype = precision - self.scales_precision: torch.dtype = scales_precision - - def prepare( - self, - model: torch.nn.Module, - *args: Any, - **kwargs: Any - ) -> torch.nn.Module: - quantize_( - model, - int8_dynamic_activation_int4_weight_fake_quantize(group_size=self.groupsize), - ) - return model - - def convert( - self, - model: torch.nn.Module, - *args: Any, - **kwargs: Any - ) -> torch.nn.Module: - unwrap_fn = _get_linear_subclass_inserter(_unwrap_affine_fake_quantized_tensor) - filter_fn = _is_linear_with_fq_weight - model = _replace_with_custom_fn_if_matches_filter(model, unwrap_fn, filter_fn) - quantize_fn = int8_dynamic_activation_int4_weight(self.groupsize) - quantize_(model, quantize_fn) - return model - - -def enable_8da4w_fake_quant(mod: torch.nn.Module): - """ - Enable fake quantization for int8 dynamic activations + int4 weight. - """ - _enable_fake_quant(mod, enable=True) - -def disable_8da4w_fake_quant(mod: torch.nn.Module): - """ - Disable fake quantization for int8 dynamic activations + int4 weight. - """ - _enable_fake_quant(mod, enable=False) - - -# ================== -# | int4wo QAT | -# ================== - -def int4_weight_only_fake_quantize(group_size=128): - """ - Applies uint4 weight-only asymmetric per-group fake quantization to linear layers. - Please see :func:`~torchao.quantization.int4_weight_only` for more details. - - Example usage:: - - from torchao.quantization import quantize_ - quantize_(model, int4_weight_only_fake_quantize(group_size=32)) - """ - def _apply_fake_quant(weight): - mapping_type = MappingType.ASYMMETRIC - block_size = (1, group_size) - target_dtype = torch.int32 - quant_min = 0 - quant_max = 15 - eps = 1e-6 - preserve_zero = False - zero_point_dtype = torch.bfloat16 - zero_point_domain = ZeroPointDomain.FLOAT - return to_affine_fake_quantized( - weight, - mapping_type, - block_size, - target_dtype, - quant_min, - quant_max, - eps, - zero_point_dtype=zero_point_dtype, - preserve_zero=preserve_zero, - zero_point_domain=zero_point_domain, - ) - return _get_qat_linear_subclass_inserter(_apply_fake_quant) - -class Int4WeightOnlyQATQuantizer(TwoStepQuantizer): - """ - Quantizer for performing QAT on a model, where linear layers have - int4 fake quantized grouped per channel weights. - """ - - def __init__( - self, - groupsize: int = 256, - inner_k_tiles: Optional[int] = 8, - precision: torch.dtype = torch.bfloat16, - scales_precision: torch.dtype = torch.bfloat16, - ) -> None: - super().__init__() - assert inner_k_tiles in [2, 4, 8] - assert groupsize in [32, 64, 128, 256] - self.inner_k_tiles = inner_k_tiles - self.groupsize = groupsize - self.precision = precision - self.scales_precision = scales_precision - - def prepare( - self, - model: torch.nn.Module, - *args: Any, - **kwargs: Any - ) -> torch.nn.Module: - quantize_(model, int4_weight_only_fake_quantize(group_size=self.groupsize)) - return model - - def convert( - self, - model: torch.nn.Module, - *args: Any, - **kwargs: Any - ) -> torch.nn.Module: - unwrap_fn = _get_linear_subclass_inserter(_unwrap_affine_fake_quantized_tensor) - filter_fn = _is_linear_with_fq_weight - model = _replace_with_custom_fn_if_matches_filter(model, unwrap_fn, filter_fn) - layout_type = TensorCoreTiledLayoutType(self.inner_k_tiles) - quantize_fn = int4_weight_only(self.groupsize, layout_type) - quantize_(model, quantize_fn) - return model - -def enable_4w_fake_quant(mod: torch.nn.Module): - """ - Enable fake quantization for int4 weight only. - """ - _enable_fake_quant(mod, enable=True) - -def disable_4w_fake_quant(mod: torch.nn.Module): - """ - Disable fake quantization for int4 weight only. - """ - _enable_fake_quant(mod, enable=False) diff --git a/torchao/quantization/prototype/qat/linear.py b/torchao/quantization/prototype/qat/linear.py new file mode 100644 index 0000000000..07276ba84c --- /dev/null +++ b/torchao/quantization/prototype/qat/linear.py @@ -0,0 +1,377 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Optional + +import torch +import torch.nn.functional as F + +from torchao.quantization.GPTQ import ( + _check_linear_int4_k, + _replace_linear_int4, + _replace_linear_8da4w, + get_groupwise_affine_qparams, + groupwise_affine_quantize_tensor, + Int8DynActInt4WeightLinear, + WeightOnlyInt4Linear, +) +from torchao.quantization.quant_primitives import ZeroPointDomain +from torchao.quantization.unified import TwoStepQuantizer +from torchao.quantization.utils import get_group_qparams_symmetric +from .utils import ( + _choose_qparams_per_token_asymmetric, + _fake_quantize_per_channel_group, + _fake_quantize_per_token, + _get_qmin_qmax, +) + + +# ========================================================= +# | Linear int8 dynamic activations + int4 weight QAT | +# ========================================================= + + +class Int8DynActInt4WeightQATQuantizer(TwoStepQuantizer): + """ + Quantizer for performing QAT on a model, where linear layers have int8 + dynamic per token fake quantized activations and int4 fake quantized + grouped per channel weights. + """ + + def __init__( + self, + groupsize: int = 256, + padding_allowed: bool = False, + precision: torch.dtype = torch.float32, + scales_precision: torch.dtype = torch.float32, + ) -> None: + super().__init__() + self.groupsize: int = groupsize + self.padding_allowed: bool = padding_allowed + self.precision: torch.dtype = precision + self.scales_precision: torch.dtype = scales_precision + + def prepare( + self, + model: torch.nn.Module, + *args: Any, + **kwargs: Any + ) -> torch.nn.Module: + _replace_linear_8da4w( + model, + self.groupsize, + self.padding_allowed, + self.precision, + self.scales_precision, + Int8DynActInt4WeightQATLinear, + copy_weights=True, + ) + return model + + def convert( + self, + model: torch.nn.Module, + *args: Any, + **kwargs: Any + ) -> torch.nn.Module: + _convert_qat_linear_8da4w(model) + return model + + +def _convert_qat_linear_8da4w(module: torch.nn.Module): + """ + Replace all `Int8DynActInt4WeightQATLinear` with `Int8DynActInt4WeightLinear`. + """ + for name, child in module.named_children(): + if isinstance(child, Int8DynActInt4WeightQATLinear): + quantized_linear = Int8DynActInt4WeightLinear( + child.in_features, + child.out_features, + bias=False, + groupsize=child.groupsize, + precision=child.precision, + scales_precision=child.scales_precision, + ) + setattr(module, name, quantized_linear) + + # Load weights and qparams into quantized linear + n_bit = 4 + (qmin, qmax) = _get_qmin_qmax(n_bit) + (s, zp) = get_group_qparams_symmetric(child.weight, n_bit, child.groupsize) + from torchao._executorch_ops import _quantized_decomposed_quantize_per_channel_group_wrapper + q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper( + child.weight, s, zp, qmin, qmax, torch.int8, child.groupsize, + ) + quantized_linear.weight = q_weight + quantized_linear.scales = s + quantized_linear.zeros = zp + else: + _convert_qat_linear_8da4w(child) + + +class Int8DynActInt4WeightQATLinear(torch.nn.Linear): + """ + This module implements a linear layer with int8 dynamic per token fake + quantized activations with int4 fake quantized grouped per channel weights. + + args: + groupsize: the number of elements in each quantized group for weights + precision: precision of weights + scales_precision: precision of per group scales and zero points + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = False, + device: torch.device = None, + groupsize: int = 256, + precision: torch.dtype = torch.float32, + scales_precision: torch.dtype = torch.float32, + ) -> None: + super().__init__( + in_features, + out_features, + bias, + device=device, + dtype=precision, + ) + assert ( + in_features % groupsize == 0 + ), f"require in_features:{in_features} % groupsize:{groupsize} == 0" + assert not bias, "require bias=False" + self.groupsize = groupsize + self.precision = precision + self.scales_precision = scales_precision + # TODO: make this configurable? + self.zero_points_precision = torch.int32 + self._fake_quant_enabled = True + + def enable_fake_quant(self, enabled: bool = True): + self._fake_quant_enabled = enabled + + def disable_fake_quant(self): + self.enable_fake_quant(False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # activations: int8 dynamic asymmetric quant + if self._fake_quant_enabled: + (act_scales, act_zp) = _choose_qparams_per_token_asymmetric( + x, self.scales_precision, self.zero_points_precision, + ) + (act_qmin, act_qmax) = _get_qmin_qmax(8) + x_fq = _fake_quantize_per_token( + x, act_scales, act_zp, act_qmin, act_qmax, + ) + else: + x_fq = x + + # weights: int4 grouped per channel symmetric quant + if self._fake_quant_enabled: + (weight_scales, weight_zp) = get_group_qparams_symmetric( + self.weight, 4, self.groupsize, self.scales_precision, + ) + # TODO: pass zp dtype to `get_group_qparams_symmetric` instead + weight_zp = weight_zp.to(self.zero_points_precision) + (weight_qmin, weight_qmax) = _get_qmin_qmax(4) + w_fq = _fake_quantize_per_channel_group( + self.weight, + weight_scales, + weight_zp, + weight_qmin, + weight_qmax, + self.groupsize, + ) + else: + w_fq = self.weight + return F.linear(x_fq, w_fq) + + +def enable_8da4w_fake_quant(mod: torch.nn.Module): + """ + Enable fake quantization for `Int8DynActInt4WeightQATLinear`. + """ + if isinstance(mod, Int8DynActInt4WeightQATLinear): + mod.enable_fake_quant() + + +def disable_8da4w_fake_quant(mod: torch.nn.Module): + """ + Disable fake quantization for `Int8DynActInt4WeightQATLinear`. + """ + if isinstance(mod, Int8DynActInt4WeightQATLinear): + mod.disable_fake_quant() + + +# =================================== +# | Linear int4 weight-only QAT | +# =================================== + + +class Int4WeightOnlyQATQuantizer(TwoStepQuantizer): + """ + Quantizer for performing QAT on a model, where linear layers have + int4 fake quantized grouped per channel weights. + """ + + def __init__( + self, + groupsize: int = 256, + inner_k_tiles: Optional[int] = 8, + precision: torch.dtype = torch.bfloat16, + scales_precision: torch.dtype = torch.bfloat16, + ) -> None: + super().__init__() + assert inner_k_tiles in [2, 4, 8] + assert groupsize in [32, 64, 128, 256] + self.inner_k_tiles = inner_k_tiles + self.groupsize = groupsize + self.precision = precision + self.scales_precision = scales_precision + + def prepare( + self, + model: torch.nn.Module, + *args: Any, + **kwargs: Any + ) -> torch.nn.Module: + _replace_linear_int4( + model, + self.groupsize, + self.inner_k_tiles, + padding_allowed=True, + precision=self.precision, + scales_precision=self.scales_precision, + linear_class=Int4WeightOnlyQATLinear, + copy_weights=True, + ) + return model + + def convert( + self, + model: torch.nn.Module, + *args: Any, + **kwargs: Any + ) -> torch.nn.Module: + _convert_qat_linear_4w(model) + return model + + +def _convert_qat_linear_4w(module: torch.nn.Module): + """ + Replace all `Int4WeightOnlyQATLinear` with `WeightOnlyInt4Linear`. + """ + for name, child in module.named_children(): + if isinstance(child, Int4WeightOnlyQATLinear): + in_features = child.in_features + out_features = child.out_features + groupsize = child.groupsize + inner_k_tiles = child.inner_k_tiles + quantized_linear = WeightOnlyInt4Linear( + in_features, + out_features, + bias=False, + groupsize=groupsize, + inner_k_tiles=inner_k_tiles, + precision=child.precision, + scales_precision=child.scales_precision, + ) + setattr(module, name, quantized_linear) + + # Load weights and qparams into quantized linear + n_bit = 4 + (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor( + child.weight, n_bit, child.groupsize, + ) + q_weight = torch.ops.aten._convert_weight_to_int4pack( + q_weight.to(child.weight.device), child.inner_k_tiles, + ) + quantized_linear.weight = q_weight + quantized_linear.scales_and_zeros = scales_and_zeros + else: + _convert_qat_linear_4w(child) + + +class Int4WeightOnlyQATLinear(torch.nn.Linear): + """ + This module implements a linear layer with int4 fake quantized grouped + per channel weights, with forward numerics matching `WeightOnlyInt4Linear`, + which uses the efficient int4 tinygemm kernel. + + args: + groupsize: the number of elements in each quantized group for weights + precision: precision of weights + scales_precision: precision of per group scales and zero points + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = False, + device: torch.device = None, + groupsize: int = 256, + inner_k_tiles: int = 8, + precision: torch.dtype = torch.bfloat16, + scales_precision: torch.dtype = torch.bfloat16, + ) -> None: + super().__init__( + in_features, + out_features, + bias, + device=device, + dtype=precision, + ) + assert not bias, "require bias=False" + assert scales_precision == torch.bfloat16, "only bf16 is supported for scales" + if not _check_linear_int4_k(in_features, groupsize, inner_k_tiles): + raise ValueError("Padding for QAT 4w is not supported yet") + self.groupsize = groupsize + self.inner_k_tiles = inner_k_tiles + self.precision = precision + self.scales_precision = scales_precision + self._fake_quant_enabled = True + + def enable_fake_quant(self, enabled: bool = True): + self._fake_quant_enabled = enabled + + def disable_fake_quant(self): + self.enable_fake_quant(False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + n_bit = 4 + qmin = 0 + qmax = 2 ** n_bit - 1 + scales, zero_points = get_groupwise_affine_qparams( + self.weight, n_bit, self.groupsize, self.scales_precision, + ) + w_fq = _fake_quantize_per_channel_group( + self.weight, + scales, + zero_points, + qmin, + qmax, + self.groupsize, + ZeroPointDomain.FLOAT, + ) + return F.linear(x, w_fq) + + +def enable_4w_fake_quant(mod: torch.nn.Module): + """ + Enable fake quantization for `Int4WeightOnlyQATLinear`. + """ + if isinstance(mod, Int4WeightOnlyQATLinear): + mod.enable_fake_quant() + + +def disable_4w_fake_quant(mod: torch.nn.Module): + """ + Disable fake quantization for `Int4WeightOnlyQATLinear`. + """ + if isinstance(mod, Int4WeightOnlyQATLinear): + mod.disable_fake_quant() diff --git a/torchao/quantization/prototype/qat/utils.py b/torchao/quantization/prototype/qat/utils.py index 1e4b61b8ac..354475e655 100644 --- a/torchao/quantization/prototype/qat/utils.py +++ b/torchao/quantization/prototype/qat/utils.py @@ -181,85 +181,6 @@ def _choose_qparams_per_token_asymmetric( return scale.to(scales_precision), zero_point.to(zero_points_precision) -def _forward_pre_hook_handler( - mod: torch.nn.Linear, - prehook: Callable, - handler: torch.utils.hooks.RemovableHandle, -): - """ - Store a 2-tuple (prehook function, handler) as an attribute on the given linear module. - """ - setattr(mod, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK, (prehook, handler)) - -def _unwrap_affine_fake_quantized_tensor(t: torch.Tensor): - """ - Return the original, non-fake-quantized float tensor from a `AffineFakeQuantizedTensor`. - """ - # avoid circular dependencies - from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import ( - AffineFakeQuantizedTensor, - ) - assert isinstance(t, AffineFakeQuantizedTensor) - return t.original_tensor - -def _is_linear_with_fq_weight(mod: torch.nn.Module, *args): - """ - Return whether this is a nn.Linear module with `AffineFakeQuantizeTensor` weights. - """ - # avoid circular dependencies - from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import ( - AffineFakeQuantizedTensor, - ) - if not isinstance(mod, torch.nn.Linear) or not hasattr(mod, "weight"): - return False - weight = mod.weight - return isinstance(weight, AffineFakeQuantizedTensor) - -def _enable_fake_quant(mod: torch.nn.Module, enable: bool): - """ - Enable or disable fake quantization in the activations and weights of a `nn.Linear` module. - """ - from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import ( - AffineFakeQuantizedTensor, - ) - if not _is_linear_with_fq_weight(mod): - return - weight = mod.weight - assert isinstance(weight, AffineFakeQuantizedTensor) - weight.fake_quant_enabled = enable - - # Enable/disable input fake quant - if hasattr(mod, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK): - (prehook, handle) = getattr(mod, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK) - if enable and handle is None: - handle = mod.register_forward_pre_hook(prehook) - elif not enable and handle is not None: - handle.remove() - handle = None - setattr(mod, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK, (prehook, handle)) - -def _get_qat_linear_subclass_inserter( - weight_constructor: Callable, - input_constructor: Optional[Callable] = None, -) -> Callable: - """ - Return a function that inserts wraps the weight and/or input activation of a - linear module in tensor subclasses. - - Args: - weight_constructor: constructor of the weight subclass, accepts a tensor - input_constructor: (optional) constructor of the input subclass, accepts a tensor - """ - def insert_subclass(lin): - lin.weight = torch.nn.Parameter(weight_constructor(lin.weight), requires_grad=True) - if input_constructor is not None: - prehook = lambda _, args: tuple([input_constructor(args[0])] + list(args[1:])) - handle = lin.register_forward_pre_hook(prehook) - setattr(lin, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK, (prehook, handle)) - return lin - - return insert_subclass - def _get_qmin_qmax(n_bit: int): qmin = -(2 ** (n_bit - 1)) qmax = 2 ** (n_bit - 1) - 1