From b370c7cb91fa59833a90b2dbb73d83d769e1c2f4 Mon Sep 17 00:00:00 2001
From: jainapurva
Date: Mon, 3 Feb 2025 14:52:00 -0800
Subject: [PATCH 1/3] Add int8 and fpx test to TensorParallel

---
 .../test_affine_quantized_tensor_parallel.py | 26 +++++++++++++++++++
 1 file changed, 26 insertions(+)

diff --git a/test/dtypes/test_affine_quantized_tensor_parallel.py b/test/dtypes/test_affine_quantized_tensor_parallel.py
index 3abb736f92..a3d793a128 100644
--- a/test/dtypes/test_affine_quantized_tensor_parallel.py
+++ b/test/dtypes/test_affine_quantized_tensor_parallel.py
@@ -12,7 +12,9 @@
 from torchao.quantization import (
     float8_dynamic_activation_float8_weight,
     float8_weight_only,
+    fpx_weight_only,
     int4_weight_only,
+    int8_dynamic_activation_int8_weight,
     int8_weight_only,
 )
 from torchao.quantization.observer import PerRow, PerTensor
@@ -166,9 +168,33 @@ def test_tp_gemlite(self, dtype):
         return self._test_tp(dtype)
 
 
+class TestInt8dqAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel):
+    QUANT_METHOD_FN = staticmethod(int8_dynamic_activation_int8_weight)
+    COMMON_DTYPES = [torch.bfloat16]
+
+    @common_utils.parametrize("dtype", COMMON_DTYPES)
+    @with_comms
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_tp(self, dtype):
+        return self._test_tp(dtype)
+
+
+class TestFpxwoAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel):
+    QUANT_METHOD_FN = staticmethod(fpx_weight_only)
+    COMMON_DTYPES = [torch.bfloat16]
+
+    @common_utils.parametrize("dtype", COMMON_DTYPES)
+    @with_comms
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_tp(self, dtype):
+        return self._test_tp(dtype)
+
+
 common_utils.instantiate_parametrized_tests(TestInt8woAffineQuantizedTensorParallel)
 common_utils.instantiate_parametrized_tests(TestInt4woAffineQuantizedTensorParallel)
 common_utils.instantiate_parametrized_tests(TestGemliteLayoutTensorParallel)
+common_utils.instantiate_parametrized_tests(TestInt8dqAffineQuantizedTensorParallel)
+common_utils.instantiate_parametrized_tests(TestFpxwoAffineQuantizedTensorParallel)
 
 # Run only on H100
 if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0):

From 12e2a7c0f04ea59185c11393b8e3aa4373882559 Mon Sep 17 00:00:00 2001
From: jainapurva
Date: Mon, 3 Feb 2025 15:03:05 -0800
Subject: [PATCH 2/3] Lint fixes

---
 torchao/quantization/quant_api.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index bbe9b1cb6b..7154957a21 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -450,7 +450,9 @@ def _linear_extra_repr(self):
     return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={_quantization_type(self.weight)}"
 
 
-def _get_linear_subclass_inserter(constructor, *, allow_requires_grad=False, propagate_bias=False, **kwargs):
+def _get_linear_subclass_inserter(
+    constructor, *, allow_requires_grad=False, propagate_bias=False, **kwargs
+):
     """Helper function to apply the constructor that quantizes the weight Tensor (with additional kwargs) to
     the weight of linear module
     """

From 30dff4fb44f0a54b8c4b97fd74b32a52b5dc5372 Mon Sep 17 00:00:00 2001
From: jainapurva
Date: Mon, 3 Feb 2025 16:08:56 -0800
Subject: [PATCH 3/3] Remove fpx

---
 .../dtypes/test_affine_quantized_tensor_parallel.py | 13 -------------
 1 file changed, 13 deletions(-)

diff --git a/test/dtypes/test_affine_quantized_tensor_parallel.py b/test/dtypes/test_affine_quantized_tensor_parallel.py
index a3d793a128..76b6b74a3d 100644
--- a/test/dtypes/test_affine_quantized_tensor_parallel.py
+++ b/test/dtypes/test_affine_quantized_tensor_parallel.py
@@ -12,7 +12,6 @@
 from torchao.quantization import (
     float8_dynamic_activation_float8_weight,
     float8_weight_only,
-    fpx_weight_only,
     int4_weight_only,
     int8_dynamic_activation_int8_weight,
     int8_weight_only,
@@ -179,22 +178,10 @@ def test_tp(self, dtype):
         return self._test_tp(dtype)
 
 
-class TestFpxwoAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel):
-    QUANT_METHOD_FN = staticmethod(fpx_weight_only)
-    COMMON_DTYPES = [torch.bfloat16]
-
-    @common_utils.parametrize("dtype", COMMON_DTYPES)
-    @with_comms
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    def test_tp(self, dtype):
-        return self._test_tp(dtype)
-
-
 common_utils.instantiate_parametrized_tests(TestInt8woAffineQuantizedTensorParallel)
 common_utils.instantiate_parametrized_tests(TestInt4woAffineQuantizedTensorParallel)
 common_utils.instantiate_parametrized_tests(TestGemliteLayoutTensorParallel)
 common_utils.instantiate_parametrized_tests(TestInt8dqAffineQuantizedTensorParallel)
-common_utils.instantiate_parametrized_tests(TestFpxwoAffineQuantizedTensorParallel)
 
 # Run only on H100
 if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0):
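Note: for context on what the new TestInt8dqAffineQuantizedTensorParallel case exercises, below is a minimal single-device sketch of the underlying quantization step. It assumes torchao's public quantize_ API; the toy model and shapes are illustrative and not taken from the patches.

    # Sketch: int8 dynamic-activation / int8-weight quantization of a linear model,
    # the same transform that QUANT_METHOD_FN feeds into the shared _test_tp path.
    import torch
    from torchao.quantization import int8_dynamic_activation_int8_weight, quantize_

    model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16).cuda()
    quantize_(model, int8_dynamic_activation_int8_weight())  # weights become quantized tensor subclasses

    x = torch.randn(16, 1024, dtype=torch.bfloat16, device="cuda")
    y = model(x)  # activations are quantized dynamically on each forward pass

On a multi-GPU host, an invocation along the lines of "pytest test/dtypes/test_affine_quantized_tensor_parallel.py -k Int8dq" should select just the new tensor-parallel case.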