From 04ca7432f91d2871c6a88d9743626a595f0fa409 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Sun, 20 Apr 2025 21:10:57 -0700 Subject: [PATCH 01/13] init --- .../tests/test_embedding_xbit_quantizer.py | 9 ++++---- ...est_int8_dynamic_activation_intx_weight.py | 12 ---------- torchao/quantization/GPTQ.py | 5 ++++- torchao/quantization/qat/embedding.py | 22 +++++++++++++------ torchao/quantization/qat/linear.py | 4 ++-- torchao/quantization/utils.py | 3 ++- 6 files changed, 27 insertions(+), 28 deletions(-) diff --git a/torchao/experimental/tests/test_embedding_xbit_quantizer.py b/torchao/experimental/tests/test_embedding_xbit_quantizer.py index 511a2d2c9f..31bcec1b8a 100644 --- a/torchao/experimental/tests/test_embedding_xbit_quantizer.py +++ b/torchao/experimental/tests/test_embedding_xbit_quantizer.py @@ -346,6 +346,7 @@ def test_identical_to_Int4WeightOnlyEmbeddingQATQuantizer( zero_point_precision=torch.int32, ) model = qat_quantizer.prepare(model) + prepared_model_copy = copy.deepcopy(model) expected_out = model(indices) # Convert model method 1 @@ -363,12 +364,10 @@ def test_identical_to_Int4WeightOnlyEmbeddingQATQuantizer( actual_out1 = model(indices) self.assertTrue(torch.allclose(expected_out, actual_out1)) - # TODO: method 2 does not work because the converted embedding op - # incorrectly casts output of to indices.dtype # Convert model method 2 - # qat_quantizer.convert(prepared_model_copy) - # actual_out2 = prepared_model_copy(indices) - # self.assertTrue(torch.allclose(expected_out, actual_out2)) + qat_quantizer.convert(prepared_model_copy) + actual_out2 = prepared_model_copy(indices) + self.assertTrue(torch.allclose(expected_out, actual_out2)) if __name__ == "__main__": diff --git a/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py b/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py index b217aa349e..e87305bb0b 100644 --- a/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py +++ b/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py @@ -510,11 +510,6 @@ def test_identical_to_IntXQuantizationAwareTrainingConfig( if mapping_type == MappingType.ASYMMETRIC: return - # TODO: QAT logic for non-float32 models does not match PTQ right now - # QAT's default scale-precision is float32, but PTQ's is None (which defaults to input's dtype) - if model_dtype != torch.float32: - return - assert mapping_type in [MappingType.SYMMETRIC, MappingType.ASYMMETRIC] assert act_mapping_type in [MappingType.SYMMETRIC, MappingType.ASYMMETRIC] is_symmetric = mapping_type == MappingType.SYMMETRIC @@ -587,13 +582,6 @@ def test_identical_to_IntXQuantizationAwareTrainingConfig( def test_identical_to_Int8DynActInt4WeightQATQuantizer( self, group_size, scale_dtype, model_dtype ): - # Currently this does not match - # TODO: investigat - if scale_dtype != torch.float32: - return - if model_dtype != torch.float32: - return - k0 = 512 k1 = 256 layers = [ diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py index ee2ad57c08..95bd014efd 100644 --- a/torchao/quantization/GPTQ.py +++ b/torchao/quantization/GPTQ.py @@ -933,7 +933,10 @@ def linear_forward_8da4w( groupsize, precision, ): - x = per_token_dynamic_quant(x, scale_dtype=precision, zero_point_dtype=precision) + # to match torchao.quantization.quant_api._int8_asymm_per_token_quant + x = per_token_dynamic_quant( + x, scale_dtype=torch.float64, zero_point_dtype=torch.int64 + ) # TODO: verify and remove following reshape code # origin_x_size = x.size() # x = x.reshape(-1, origin_x_size[-1]) diff --git a/torchao/quantization/qat/embedding.py b/torchao/quantization/qat/embedding.py index 42e9b08eed..0146706853 100644 --- a/torchao/quantization/qat/embedding.py +++ b/torchao/quantization/qat/embedding.py @@ -177,6 +177,7 @@ def replacement_fn(child: torch.nn.Module) -> torch.nn.Module: scale_precision=self.scale_precision, zero_point_precision=self.zero_point_precision, device=child.weight.device, + dtype=child.weight.dtype, ) # In distributed training, the model may be instantiated # on the meta device, in which case there is no need to @@ -227,13 +228,14 @@ def _convert_helper(self, module: torch.nn.Module): scale_precision=scale_precision, zero_point_precision=zero_point_precision, device=child.weight.device, + output_dtype=child.weight.dtype, ) setattr(module, name, quantized_embedding) # Load weights and qparams into quantized embedding (qmin, qmax) = _get_qmin_qmax(self.bit_width) (s, zp) = get_group_qparams_symmetric( - child.weight, self.bit_width, group_size + child.weight, self.bit_width, group_size, precision=scale_precision, ) q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper( child.weight, @@ -324,6 +326,7 @@ def __init__( scale_precision: torch.dtype = torch.float32, zero_point_precision: torch.dtype = torch.int32, device: torch.device = None, + output_dtype: torch.dtype = torch.float32, ): super().__init__() @@ -341,6 +344,7 @@ def __init__( self.group_size = group_size self.scale_precision = scale_precision self.zero_point_precision = zero_point_precision + self.output_dtype = output_dtype # currently storing unpacked int8 weights self.register_buffer( @@ -367,20 +371,24 @@ def __init__( ) def forward(self, x): - from torchao._executorch_ops import ( - _quantized_decomposed_dequantize_per_channel_group_wrapper, + from torchao.quantization.quant_primitives import ( + dequantize_affine, ) qmin, qmax = _get_qmin_qmax(self.bit_width) - w_dq = _quantized_decomposed_dequantize_per_channel_group_wrapper( + + # dequantize_affine casts to output_dtype before scaling + # dequantize_per_channel_group scales and then casts to output_dtype + # The two do not agree when dtype != torch.float32 + w_dq = dequantize_affine( self.weight, + [1, self.group_size], self.scale, self.zero_point, + torch.int8, qmin, qmax, - torch.int8, - self.group_size, - x.dtype, + output_dtype=self.output_dtype, ) return F.embedding( x, diff --git a/torchao/quantization/qat/linear.py b/torchao/quantization/qat/linear.py index 12584fade8..a5c23630be 100644 --- a/torchao/quantization/qat/linear.py +++ b/torchao/quantization/qat/linear.py @@ -219,7 +219,7 @@ def _convert_qat_linear_8da4w(self, module: torch.nn.Module): n_bit = 4 (qmin, qmax) = _get_qmin_qmax(n_bit) (s, zp) = get_group_qparams_symmetric( - child.weight, n_bit, config.group_size + child.weight, n_bit, config.group_size, precision=scale_precision, ) from torchao._executorch_ops import ( _quantized_decomposed_quantize_per_channel_group_wrapper, @@ -270,7 +270,7 @@ def __init__( precision: torch.dtype = torch.float32, scales_precision: torch.dtype = torch.float32, ) -> None: - activation_config = _get_8da4w_activation_config(scales_precision) + activation_config = _get_8da4w_activation_config(torch.float32) weight_config = _get_8da4w_weight_config(groupsize, scales_precision) super().__init__( in_features, diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index 0c30fba713..fded0fd61f 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -555,7 +555,7 @@ def get_group_qparams_symmetric( quant_max=quant_max, eps=eps, scale_dtype=precision, - zero_point_dtype=precision, + zero_point_dtype=torch.int32, ) return scale.reshape(w.shape[0], -1), zero_point.reshape(w.shape[0], -1) @@ -605,6 +605,7 @@ def per_token_dynamic_quant( quant_dtype, quant_min, quant_max, + eps=torch.finfo(torch.float32).eps, scale_dtype=scale_dtype, zero_point_dtype=zero_point_dtype, ) From ba50fec9f63be636a92f75fffed694758fb56672 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Mon, 21 Apr 2025 09:58:38 -0700 Subject: [PATCH 02/13] up --- torchao/quantization/GPTQ.py | 5 ++++- torchao/quantization/quant_api.py | 10 ++++------ torchao/quantization/utils.py | 2 +- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py index 95bd014efd..10cfae3a57 100644 --- a/torchao/quantization/GPTQ.py +++ b/torchao/quantization/GPTQ.py @@ -935,8 +935,11 @@ def linear_forward_8da4w( ): # to match torchao.quantization.quant_api._int8_asymm_per_token_quant x = per_token_dynamic_quant( - x, scale_dtype=torch.float64, zero_point_dtype=torch.int64 + x, scale_dtype=torch.float32, zero_point_dtype=torch.float32 ) + # x = per_token_dynamic_quant( + # x, scale_dtype=torch.float64, zero_point_dtype=torch.int64 + # ) # TODO: verify and remove following reshape code # origin_x_size = x.size() # x = x.reshape(-1, origin_x_size[-1]) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 4b2cd53024..142358d563 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -625,14 +625,16 @@ def _int8_asymm_per_token_quant(x: torch.Tensor) -> torch.Tensor: """This is defined here instead of local function to support serialization""" mapping_type = MappingType.ASYMMETRIC target_dtype = torch.int8 + scale_dtype = torch.float32 + zero_point_dtype = torch.int32 if TORCH_VERSION_AT_LEAST_2_6: return to_affine_quantized_intx( x, mapping_type, _get_per_token_block_size(x), target_dtype, - scale_dtype=torch.float64, - zero_point_dtype=torch.int64, + scale_dtype=scale_dtype, + zero_point_dtype=zero_point_dtype, ) else: return to_affine_quantized_intx( @@ -706,7 +708,6 @@ def _int8_dynamic_activation_int4_weight_transform( # weight settings block_size = (1, group_size) target_dtype = torch.int8 - eps = torch.finfo(torch.float32).eps quant_min = -8 quant_max = 7 @@ -737,7 +738,6 @@ def _int8_dynamic_activation_int4_weight_transform( target_dtype, quant_min, quant_max, - eps, _layout=layout, ) weight = to_linear_activation_quantized(weight, input_quant_func) @@ -858,7 +858,6 @@ def _int8_dynamic_activation_intx_weight_transform( target_dtype=torch.int8, quant_min=quant_min, quant_max=quant_max, - eps=torch.finfo(torch.float32).eps, scale_dtype=weight_scale_dtype, zero_point_dtype=torch.int8, preserve_zero=(weight_mapping_type == MappingType.SYMMETRIC), @@ -1895,7 +1894,6 @@ def _intx_weight_only_transform( target_dtype=torch.int8, quant_min=quant_min, quant_max=quant_max, - eps=torch.finfo(torch.float32).eps, scale_dtype=scale_dtype, zero_point_dtype=torch.int8, preserve_zero=(mapping_type == MappingType.SYMMETRIC), diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index fded0fd61f..47d85e513c 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -555,7 +555,7 @@ def get_group_qparams_symmetric( quant_max=quant_max, eps=eps, scale_dtype=precision, - zero_point_dtype=torch.int32, + zero_point_dtype=precision, ) return scale.reshape(w.shape[0], -1), zero_point.reshape(w.shape[0], -1) From c57544f37e5c4d03feb134e01d64c93aba1a29c1 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Mon, 21 Apr 2025 10:52:43 -0700 Subject: [PATCH 03/13] up --- .../tests/test_int8_dynamic_activation_intx_weight.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py b/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py index e87305bb0b..6e5361f09b 100644 --- a/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py +++ b/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py @@ -360,7 +360,7 @@ def test_export_QDQLayout(self): self.assertTrue(torch.allclose(eager_results, exported_results)) expected_lines = [ - "torch.ops.torchao.choose_qparams_affine.default(input_1, 'ASYMMETRIC', [1, 512], torch.int8, None, None, None, torch.float64, torch.int64)", + "torch.ops.torchao.choose_qparams_affine.default(input_1, 'ASYMMETRIC', [1, 512], torch.int8, None, None, None, torch.float32, torch.int32)", "torch.ops.torchao.quantize_affine.default(input_1, [1, 512], getitem, getitem_1, torch.int8)", "torch.ops.torchao.dequantize_affine.default(quantize_affine, [1, 512], getitem, getitem_1, torch.int8)", "torch.ops.torchao.dequantize_affine.default", From c7fb2d73eac0e90878133fc57535002290c18d9b Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Mon, 21 Apr 2025 10:58:39 -0700 Subject: [PATCH 04/13] up --- torchao/quantization/GPTQ.py | 6 ++---- torchao/quantization/utils.py | 1 - 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py index 10cfae3a57..bc8f6a8c56 100644 --- a/torchao/quantization/GPTQ.py +++ b/torchao/quantization/GPTQ.py @@ -933,13 +933,11 @@ def linear_forward_8da4w( groupsize, precision, ): - # to match torchao.quantization.quant_api._int8_asymm_per_token_quant + # uses fp32 to match torchao.quantization.quant_api._int8_asymm_per_token_quant + # and activation_scale_dtype in QAT configs x = per_token_dynamic_quant( x, scale_dtype=torch.float32, zero_point_dtype=torch.float32 ) - # x = per_token_dynamic_quant( - # x, scale_dtype=torch.float64, zero_point_dtype=torch.int64 - # ) # TODO: verify and remove following reshape code # origin_x_size = x.size() # x = x.reshape(-1, origin_x_size[-1]) diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index 47d85e513c..0c30fba713 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -605,7 +605,6 @@ def per_token_dynamic_quant( quant_dtype, quant_min, quant_max, - eps=torch.finfo(torch.float32).eps, scale_dtype=scale_dtype, zero_point_dtype=zero_point_dtype, ) From 05e2d00afe57a2e62d03b7e119eb541106a5ef60 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Mon, 21 Apr 2025 11:12:45 -0700 Subject: [PATCH 05/13] up --- torchao/experimental/quant_passes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/experimental/quant_passes.py b/torchao/experimental/quant_passes.py index ba5e21eca7..13a0a755fb 100644 --- a/torchao/experimental/quant_passes.py +++ b/torchao/experimental/quant_passes.py @@ -86,7 +86,7 @@ def _get_q_dq_linear_patterns_replacements_and_filters( glbs["a_quant_min"] = None glbs["a_quant_max"] = None glbs["a_mapping_type"] = "ASYMMETRIC" - glbs["a_scale_dtype"] = torch.float64 + glbs["a_scale_dtype"] = torch.float32 glbs["a_eps"] = None lcls = {} From 4d62fc34866fd1ee819578bb956c67c84cdb4dda Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Mon, 21 Apr 2025 13:51:41 -0700 Subject: [PATCH 06/13] up --- .../tests/test_int8_dynamic_activation_intx_weight.py | 2 +- torchao/quantization/quant_api.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py b/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py index 6e5361f09b..f22abf4a12 100644 --- a/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py +++ b/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py @@ -360,7 +360,7 @@ def test_export_QDQLayout(self): self.assertTrue(torch.allclose(eager_results, exported_results)) expected_lines = [ - "torch.ops.torchao.choose_qparams_affine.default(input_1, 'ASYMMETRIC', [1, 512], torch.int8, None, None, None, torch.float32, torch.int32)", + "torch.ops.torchao.choose_qparams_affine.default(input_1, 'ASYMMETRIC', [1, 512], torch.int8, None, None, None, torch.float32, torch.int8)", "torch.ops.torchao.quantize_affine.default(input_1, [1, 512], getitem, getitem_1, torch.int8)", "torch.ops.torchao.dequantize_affine.default(quantize_affine, [1, 512], getitem, getitem_1, torch.int8)", "torch.ops.torchao.dequantize_affine.default", diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 142358d563..eb70f0b91f 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -626,7 +626,7 @@ def _int8_asymm_per_token_quant(x: torch.Tensor) -> torch.Tensor: mapping_type = MappingType.ASYMMETRIC target_dtype = torch.int8 scale_dtype = torch.float32 - zero_point_dtype = torch.int32 + zero_point_dtype = torch.int8 if TORCH_VERSION_AT_LEAST_2_6: return to_affine_quantized_intx( x, From 26fbfd2775f376d34474cf4d0ee0f12b412562a4 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Tue, 22 Apr 2025 20:38:49 -0700 Subject: [PATCH 07/13] up --- .../tests/test_embedding_xbit_quantizer.py | 29 +++++++++++------- ...est_int8_dynamic_activation_intx_weight.py | 30 +++++++++++-------- 2 files changed, 37 insertions(+), 22 deletions(-) diff --git a/torchao/experimental/tests/test_embedding_xbit_quantizer.py b/torchao/experimental/tests/test_embedding_xbit_quantizer.py index 31bcec1b8a..442612410e 100644 --- a/torchao/experimental/tests/test_embedding_xbit_quantizer.py +++ b/torchao/experimental/tests/test_embedding_xbit_quantizer.py @@ -32,6 +32,7 @@ MappingType, quantize_, ) +from torchao.quantization.utils import compute_error class TestEmbeddingQuantizer(unittest.TestCase): @@ -254,7 +255,7 @@ def test_identical_to_IntxWeightOnlyConfig( for granularity in [PerGroup(32), PerGroup(128), PerAxis(0)] for mapping_type in [MappingType.SYMMETRIC, MappingType.ASYMMETRIC] for scale_dtype in [torch.float32, torch.bfloat16, torch.float16] - for model_dtype in [torch.float32, torch.bfloat16] + for model_dtype in [torch.float32, torch.bfloat16, torch.float16] ], name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}", ) @@ -292,7 +293,7 @@ def test_identical_to_IntXQuantizationAwareTrainingConfig( IntXQuantizationAwareTrainingConfig(weight_config=weight_config), embedding_filter, ) - expected_out = model(indices) + prepared_out = model(indices) quantize_(model, FromIntXQuantizationAwareTrainingConfig(), embedding_filter) quantize_( @@ -305,8 +306,14 @@ def test_identical_to_IntXQuantizationAwareTrainingConfig( ), embedding_filter, ) - actual_out = model(indices) - self.assertTrue(torch.allclose(expected_out, actual_out)) + converted_out = model(indices) + sqnr = compute_error(prepared_out, converted_out).item() + + # For torch.int1, sometimes sqnr is nan because both tensors are all 0 + # so we check torch.equal as well + self.assertTrue( + sqnr == float("inf") or torch.equal(prepared_out, converted_out) + ) @parameterized.expand( [ @@ -317,7 +324,7 @@ def test_identical_to_IntXQuantizationAwareTrainingConfig( ) for granularity in [PerGroup(32), PerGroup(128), PerAxis(0)] for scale_dtype in [torch.float32, torch.bfloat16, torch.float16] - for model_dtype in [torch.float32, torch.bfloat16] + for model_dtype in [torch.float32, torch.bfloat16, torch.float16] ], name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}", ) @@ -347,7 +354,7 @@ def test_identical_to_Int4WeightOnlyEmbeddingQATQuantizer( ) model = qat_quantizer.prepare(model) prepared_model_copy = copy.deepcopy(model) - expected_out = model(indices) + prepared_out = model(indices) # Convert model method 1 quantize_(model, FromIntXQuantizationAwareTrainingConfig(), embedding_filter) @@ -361,13 +368,15 @@ def test_identical_to_Int4WeightOnlyEmbeddingQATQuantizer( ), embedding_filter, ) - actual_out1 = model(indices) - self.assertTrue(torch.allclose(expected_out, actual_out1)) + converted_out1 = model(indices) + sqnr1 = compute_error(prepared_out, converted_out1).item() + self.assertTrue(sqnr1 == float("inf")) # Convert model method 2 qat_quantizer.convert(prepared_model_copy) - actual_out2 = prepared_model_copy(indices) - self.assertTrue(torch.allclose(expected_out, actual_out2)) + converted_out2 = prepared_model_copy(indices) + sqnr2 = compute_error(prepared_out, converted_out2).item() + self.assertTrue(sqnr2 == float("inf")) if __name__ == "__main__": diff --git a/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py b/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py index f22abf4a12..da6c98cd6f 100644 --- a/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py +++ b/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py @@ -26,6 +26,7 @@ MappingType, quantize_, ) +from torchao.quantization.utils import compute_error class TestInt8DynamicActivationIntxWeight(unittest.TestCase): @@ -475,7 +476,8 @@ def test_identical_to_Int8DynamicActivationInt4WeightConfig( ), ) with torch.no_grad(): - torch.allclose(model(activations), model_copy(activations)) + sqnr = compute_error(model(activations), model_copy(activations)).item() + self.assertTrue(sqnr == float("inf")) @parameterized.expand( [ @@ -492,7 +494,7 @@ def test_identical_to_Int8DynamicActivationInt4WeightConfig( for mapping_type in [MappingType.SYMMETRIC, MappingType.ASYMMETRIC] for act_mapping_type in [MappingType.ASYMMETRIC, MappingType.SYMMETRIC] for scale_dtype in [torch.float32, torch.bfloat16, torch.float16] - for model_dtype in [torch.float32, torch.bfloat16] + for model_dtype in [torch.float32, torch.bfloat16, torch.float16] ], name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}", ) @@ -545,7 +547,7 @@ def test_identical_to_IntXQuantizationAwareTrainingConfig( IntXQuantizationAwareTrainingConfig(activation_config, weight_config), ) try: - expected_out = model(activations) + prepared_out = model(activations) except NotImplementedError as e: # QAT does not support act_mapping_type == MappingType.SYMMETRIC yet if act_mapping_type == MappingType.SYMMETRIC: @@ -563,8 +565,10 @@ def test_identical_to_IntXQuantizationAwareTrainingConfig( act_mapping_type=act_mapping_type, ), ) - actual_out = model(activations) - self.assertTrue(torch.allclose(expected_out, actual_out)) + converted_out = model(activations) + + sqnr = compute_error(prepared_out, converted_out).item() + self.assertTrue(sqnr == float("inf")) @parameterized.expand( [ @@ -575,7 +579,7 @@ def test_identical_to_IntXQuantizationAwareTrainingConfig( ) for group_size in [32, 64, 128] for scale_dtype in [torch.float32, torch.bfloat16, torch.float16] - for model_dtype in [torch.float32, torch.bfloat16] + for model_dtype in [torch.float32, torch.bfloat16, torch.float16] ], name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}", ) @@ -599,10 +603,10 @@ def test_identical_to_Int8DynActInt4WeightQATQuantizer( groupsize=group_size, precision=model_dtype, scales_precision=scale_dtype ) model = qat_quantizer.prepare(model) - expected_out = model(activations) - prepared_model_copy = copy.deepcopy(model) + prepared_out = model(activations) + # Convert model method 1 quantize_(model, FromIntXQuantizationAwareTrainingConfig()) quantize_( @@ -615,13 +619,15 @@ def test_identical_to_Int8DynActInt4WeightQATQuantizer( act_mapping_type=MappingType.ASYMMETRIC, ), ) - actual_out1 = model(activations) - self.assertTrue(torch.allclose(expected_out, actual_out1)) + converted_out1 = model(activations) + sqnr1 = compute_error(prepared_out, converted_out1).item() + self.assertTrue(sqnr1 == float("inf")) # Convert model method 2 qat_quantizer.convert(prepared_model_copy) - actual_out2 = prepared_model_copy(activations) - self.assertTrue(torch.allclose(expected_out, actual_out2)) + converted_out2 = prepared_model_copy(activations) + sqnr2 = compute_error(prepared_out, converted_out2).item() + self.assertTrue(sqnr2 == float("inf")) if __name__ == "__main__": From 808372fdb458c9b745ba4eee42196bfe7cfbfc42 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Tue, 22 Apr 2025 20:41:02 -0700 Subject: [PATCH 08/13] up --- torchao/experimental/op_lib.py | 6 +++--- torchao/experimental/quant_api.py | 28 +++++++++++++--------------- torchao/experimental/quant_passes.py | 12 ++++++------ 3 files changed, 22 insertions(+), 24 deletions(-) diff --git a/torchao/experimental/op_lib.py b/torchao/experimental/op_lib.py index 4fe478d1e8..716ae469f0 100644 --- a/torchao/experimental/op_lib.py +++ b/torchao/experimental/op_lib.py @@ -13,9 +13,9 @@ # Load C++ ops lib_path = Path(__file__).parent.parent libs = list(lib_path.glob("libtorchao_ops_aten.*")) -assert len(libs) == 1, ( - f"Expected to find one libtorchao_ops_aten.* library at {lib_path}, but found {len(libs)}" -) +assert ( + len(libs) == 1 +), f"Expected to find one libtorchao_ops_aten.* library at {lib_path}, but found {len(libs)}" torch.ops.load_library(str(libs[0])) diff --git a/torchao/experimental/quant_api.py b/torchao/experimental/quant_api.py index b7630cada3..6e78e3d853 100644 --- a/torchao/experimental/quant_api.py +++ b/torchao/experimental/quant_api.py @@ -225,12 +225,12 @@ def _replace_embedding_with_quantized_embedding( packed_weight = weight_tensor.tensor_impl.packed_weight bit_width = weight_tensor.tensor_impl.get_layout().bit_width - assert n == child.num_embeddings, ( - "num_embeddings must match n in shared_unembedding" - ) - assert k == child.embedding_dim, ( - "embedding_dim must match k in shared_unembedding" - ) + assert ( + n == child.num_embeddings + ), "num_embeddings must match n in shared_unembedding" + assert ( + k == child.embedding_dim + ), "embedding_dim must match k in shared_unembedding" qembedding = QuantizedSharedEmbedding( bit_width, packed_weight, @@ -420,18 +420,16 @@ def quantize( # Check that embeddings are shared, embeddings are embeddings, and unembeddings are linear ops for embedding_fqn, unembedding_fqn in embedding_to_unembedding.items(): - assert embedding_fqn in embedding_fqns, ( - f"Embedding {embedding_fqn} is not found in model" - ) - assert unembedding_fqn in linear_fqns, ( - f"Unembedding {unembedding_fqn} is not found in model" - ) + assert ( + embedding_fqn in embedding_fqns + ), f"Embedding {embedding_fqn} is not found in model" + assert ( + unembedding_fqn in linear_fqns + ), f"Unembedding {unembedding_fqn} is not found in model" assert torch.allclose( state_dict[embedding_fqn + ".weight"], state_dict[unembedding_fqn + ".weight"], - ), ( - f"Embedding {embedding_fqn} does not share weights with unembedding {unembedding_fqn}" - ) + ), f"Embedding {embedding_fqn} does not share weights with unembedding {unembedding_fqn}" # Quantize unembeddings quantize_( diff --git a/torchao/experimental/quant_passes.py b/torchao/experimental/quant_passes.py index 13a0a755fb..17c5d7ab03 100644 --- a/torchao/experimental/quant_passes.py +++ b/torchao/experimental/quant_passes.py @@ -195,9 +195,9 @@ def replace_q_dq_patterns_with_quantized_linear_ops_pass( """ # TODO: figure out how to do this with dynamic_shapes (not saved on EP for easy re-export) # See https://fb.workplace.com/groups/1028545332188949/permalink/1185289956514485/ - assert len(ep.range_constraints) == 0, ( - "ExportedProgram with range constraints are not supported" - ) + assert ( + len(ep.range_constraints) == 0 + ), "ExportedProgram with range constraints are not supported" # ep.module() unlifts the weight inputs, which we need for constant folding gm = ep.module() @@ -295,9 +295,9 @@ def replace_q_dq_patterns_with_quantized_embedding_ops_pass( """ # TODO: figure out how to do this with dynamic_shapes (not saved on EP for easy re-export) # See https://fb.workplace.com/groups/1028545332188949/permalink/1185289956514485/ - assert len(ep.range_constraints) == 0, ( - "ExportedProgram with range constraints are not supported" - ) + assert ( + len(ep.range_constraints) == 0 + ), "ExportedProgram with range constraints are not supported" # ep.module() unlifts the weight inputs, which we need for constant folding gm = ep.module() From a64e271bcb221fe1b926cd1ee75900c5f4e8826d Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Tue, 22 Apr 2025 20:47:27 -0700 Subject: [PATCH 09/13] up --- torchao/experimental/op_lib.py | 6 +++--- torchao/experimental/quant_api.py | 28 +++++++++++++++------------- torchao/experimental/quant_passes.py | 12 ++++++------ 3 files changed, 24 insertions(+), 22 deletions(-) diff --git a/torchao/experimental/op_lib.py b/torchao/experimental/op_lib.py index 716ae469f0..4fe478d1e8 100644 --- a/torchao/experimental/op_lib.py +++ b/torchao/experimental/op_lib.py @@ -13,9 +13,9 @@ # Load C++ ops lib_path = Path(__file__).parent.parent libs = list(lib_path.glob("libtorchao_ops_aten.*")) -assert ( - len(libs) == 1 -), f"Expected to find one libtorchao_ops_aten.* library at {lib_path}, but found {len(libs)}" +assert len(libs) == 1, ( + f"Expected to find one libtorchao_ops_aten.* library at {lib_path}, but found {len(libs)}" +) torch.ops.load_library(str(libs[0])) diff --git a/torchao/experimental/quant_api.py b/torchao/experimental/quant_api.py index 6e78e3d853..b7630cada3 100644 --- a/torchao/experimental/quant_api.py +++ b/torchao/experimental/quant_api.py @@ -225,12 +225,12 @@ def _replace_embedding_with_quantized_embedding( packed_weight = weight_tensor.tensor_impl.packed_weight bit_width = weight_tensor.tensor_impl.get_layout().bit_width - assert ( - n == child.num_embeddings - ), "num_embeddings must match n in shared_unembedding" - assert ( - k == child.embedding_dim - ), "embedding_dim must match k in shared_unembedding" + assert n == child.num_embeddings, ( + "num_embeddings must match n in shared_unembedding" + ) + assert k == child.embedding_dim, ( + "embedding_dim must match k in shared_unembedding" + ) qembedding = QuantizedSharedEmbedding( bit_width, packed_weight, @@ -420,16 +420,18 @@ def quantize( # Check that embeddings are shared, embeddings are embeddings, and unembeddings are linear ops for embedding_fqn, unembedding_fqn in embedding_to_unembedding.items(): - assert ( - embedding_fqn in embedding_fqns - ), f"Embedding {embedding_fqn} is not found in model" - assert ( - unembedding_fqn in linear_fqns - ), f"Unembedding {unembedding_fqn} is not found in model" + assert embedding_fqn in embedding_fqns, ( + f"Embedding {embedding_fqn} is not found in model" + ) + assert unembedding_fqn in linear_fqns, ( + f"Unembedding {unembedding_fqn} is not found in model" + ) assert torch.allclose( state_dict[embedding_fqn + ".weight"], state_dict[unembedding_fqn + ".weight"], - ), f"Embedding {embedding_fqn} does not share weights with unembedding {unembedding_fqn}" + ), ( + f"Embedding {embedding_fqn} does not share weights with unembedding {unembedding_fqn}" + ) # Quantize unembeddings quantize_( diff --git a/torchao/experimental/quant_passes.py b/torchao/experimental/quant_passes.py index 17c5d7ab03..13a0a755fb 100644 --- a/torchao/experimental/quant_passes.py +++ b/torchao/experimental/quant_passes.py @@ -195,9 +195,9 @@ def replace_q_dq_patterns_with_quantized_linear_ops_pass( """ # TODO: figure out how to do this with dynamic_shapes (not saved on EP for easy re-export) # See https://fb.workplace.com/groups/1028545332188949/permalink/1185289956514485/ - assert ( - len(ep.range_constraints) == 0 - ), "ExportedProgram with range constraints are not supported" + assert len(ep.range_constraints) == 0, ( + "ExportedProgram with range constraints are not supported" + ) # ep.module() unlifts the weight inputs, which we need for constant folding gm = ep.module() @@ -295,9 +295,9 @@ def replace_q_dq_patterns_with_quantized_embedding_ops_pass( """ # TODO: figure out how to do this with dynamic_shapes (not saved on EP for easy re-export) # See https://fb.workplace.com/groups/1028545332188949/permalink/1185289956514485/ - assert ( - len(ep.range_constraints) == 0 - ), "ExportedProgram with range constraints are not supported" + assert len(ep.range_constraints) == 0, ( + "ExportedProgram with range constraints are not supported" + ) # ep.module() unlifts the weight inputs, which we need for constant folding gm = ep.module() From 24e4f7ca8a2f98b84f38bea33490b7c63d8cc881 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Wed, 23 Apr 2025 11:18:47 -0700 Subject: [PATCH 10/13] up --- torchao/quantization/qat/embedding.py | 5 ++++- torchao/quantization/qat/linear.py | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/torchao/quantization/qat/embedding.py b/torchao/quantization/qat/embedding.py index 0146706853..7f252c15d6 100644 --- a/torchao/quantization/qat/embedding.py +++ b/torchao/quantization/qat/embedding.py @@ -235,7 +235,10 @@ def _convert_helper(self, module: torch.nn.Module): # Load weights and qparams into quantized embedding (qmin, qmax) = _get_qmin_qmax(self.bit_width) (s, zp) = get_group_qparams_symmetric( - child.weight, self.bit_width, group_size, precision=scale_precision, + child.weight, + self.bit_width, + group_size, + precision=scale_precision, ) q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper( child.weight, diff --git a/torchao/quantization/qat/linear.py b/torchao/quantization/qat/linear.py index a5c23630be..954f29c0c2 100644 --- a/torchao/quantization/qat/linear.py +++ b/torchao/quantization/qat/linear.py @@ -219,7 +219,10 @@ def _convert_qat_linear_8da4w(self, module: torch.nn.Module): n_bit = 4 (qmin, qmax) = _get_qmin_qmax(n_bit) (s, zp) = get_group_qparams_symmetric( - child.weight, n_bit, config.group_size, precision=scale_precision, + child.weight, + n_bit, + config.group_size, + precision=config.scale_precision, ) from torchao._executorch_ops import ( _quantized_decomposed_quantize_per_channel_group_wrapper, From 718e650b4b71e5b7a2bd68aaa920e0e66b7d83ca Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Fri, 25 Apr 2025 14:07:05 -0700 Subject: [PATCH 11/13] up --- torchao/quantization/GPTQ.py | 9 ++++----- torchao/quantization/qat/embedding.py | 1 + torchao/quantization/qat/linear.py | 7 +++++++ 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py index bc8f6a8c56..54b5180f10 100644 --- a/torchao/quantization/GPTQ.py +++ b/torchao/quantization/GPTQ.py @@ -18,6 +18,7 @@ from torch.utils._pytree import tree_flatten, tree_unflatten from torchao.dtypes.utils import is_device +from torchao.quantization.quant_api import _int8_asymm_per_token_quant from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_6, @@ -931,13 +932,11 @@ def linear_forward_8da4w( zeros, out_features, groupsize, - precision, + output_precision, ): # uses fp32 to match torchao.quantization.quant_api._int8_asymm_per_token_quant # and activation_scale_dtype in QAT configs - x = per_token_dynamic_quant( - x, scale_dtype=torch.float32, zero_point_dtype=torch.float32 - ) + x = _int8_asymm_per_token_quant(x).dequantize() # TODO: verify and remove following reshape code # origin_x_size = x.size() # x = x.reshape(-1, origin_x_size[-1]) @@ -957,7 +956,7 @@ def linear_forward_8da4w( torch.int8, quant_min, quant_max, - output_dtype=precision, + output_dtype=output_precision, ) # x = x.to(torch.float16) diff --git a/torchao/quantization/qat/embedding.py b/torchao/quantization/qat/embedding.py index 7f252c15d6..2770956a2c 100644 --- a/torchao/quantization/qat/embedding.py +++ b/torchao/quantization/qat/embedding.py @@ -240,6 +240,7 @@ def _convert_helper(self, module: torch.nn.Module): group_size, precision=scale_precision, ) + zp = zp.to(zero_point_precision) q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper( child.weight, s, diff --git a/torchao/quantization/qat/linear.py b/torchao/quantization/qat/linear.py index 954f29c0c2..4b0b160c53 100644 --- a/torchao/quantization/qat/linear.py +++ b/torchao/quantization/qat/linear.py @@ -224,6 +224,7 @@ def _convert_qat_linear_8da4w(self, module: torch.nn.Module): config.group_size, precision=config.scale_precision, ) + zp = zp.to(config.zero_point_precision) from torchao._executorch_ops import ( _quantized_decomposed_quantize_per_channel_group_wrapper, ) @@ -261,6 +262,10 @@ class Int8DynActInt4WeightQATLinear(FakeQuantizedLinear): groupsize: the number of elements in each quantized group for weights precision: precision of weights scales_precision: precision of per group scales and zero points + + Note: we hardcode activation scales to use torch.fp32, but allow users to specify the weight scales (defaults to torch.fp32). + To get an exact numerical match with Int8DynamicActivationInt4WeightConfig, users must use the same dtype for both the weights + and the scales. Here scales_precision refers specifically to the weight scales only, not the activation scales. """ def __init__( @@ -273,6 +278,8 @@ def __init__( precision: torch.dtype = torch.float32, scales_precision: torch.dtype = torch.float32, ) -> None: + # Use torch.float32 to match torchao.quantization.quant_api._int8_asymm_per_token_quant, + # which is used in PTQ routines activation_config = _get_8da4w_activation_config(torch.float32) weight_config = _get_8da4w_weight_config(groupsize, scales_precision) super().__init__( From 2b077188cfd79640cf2b2b1dfae7932cfeda79f7 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Fri, 25 Apr 2025 14:18:04 -0700 Subject: [PATCH 12/13] up --- torchao/quantization/GPTQ.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py index 54b5180f10..94fccc7bf1 100644 --- a/torchao/quantization/GPTQ.py +++ b/torchao/quantization/GPTQ.py @@ -18,7 +18,6 @@ from torch.utils._pytree import tree_flatten, tree_unflatten from torchao.dtypes.utils import is_device -from torchao.quantization.quant_api import _int8_asymm_per_token_quant from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_6, @@ -936,7 +935,12 @@ def linear_forward_8da4w( ): # uses fp32 to match torchao.quantization.quant_api._int8_asymm_per_token_quant # and activation_scale_dtype in QAT configs - x = _int8_asymm_per_token_quant(x).dequantize() + # TODO: in future add ability to specify activation_scale_dtype to PTQ configs + # and enable similar change here + x = per_token_dynamic_quant( + x, scale_dtype=torch.float32, zero_point_dtype=torch.float32 + ) + # TODO: verify and remove following reshape code # origin_x_size = x.size() # x = x.reshape(-1, origin_x_size[-1]) From 5eab1cf92a5dce264e3f7ae4f246153f61a4a40b Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Fri, 25 Apr 2025 14:20:27 -0700 Subject: [PATCH 13/13] up --- test/quantization/test_qat.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index f3e6515b78..075671a043 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -1474,7 +1474,6 @@ def test_fake_quantize_per_token_vs_convert(self, dtype: torch.dtype): @unittest.skipIf( not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" ) - @unittest.skip("Currently failing on sqnr") def test_qat_8da4w_prepare_vs_convert(self, dtype: torch.dtype): """ Test that the prepare and convert steps of Int8DynActInt4QATQuantizer produces @@ -1493,7 +1492,9 @@ def test_qat_8da4w_prepare_vs_convert(self, dtype: torch.dtype): torch.manual_seed(seed) x = m.example_inputs() - quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size) + quantizer = Int8DynActInt4WeightQATQuantizer( + groupsize=group_size, precision=dtype, scales_precision=dtype + ) prepared = quantizer.prepare(m) prepared_out = prepared(*x) converted = quantizer.convert(prepared)