Add a tensor-parallel test case for int8_dynamic_activation_int8_weight.

This patch touches test/dtypes/test_affine_quantized_tensor_parallel.py only:
  * adds int8_dynamic_activation_int8_weight to the existing multi-line
    import list (placed before int8_weight_only);
  * adds TestInt8dqAffineQuantizedTensorParallel, a subclass of
    TestAffineQuantizedTensorParallel that sets QUANT_METHOD_FN to the new
    import, parametrizes test_tp over torch.bfloat16 only, skips when CUDA
    is unavailable, and delegates to self._test_tp(dtype);
  * registers the new class via common_utils.instantiate_parametrized_tests.

diff --git a/test/dtypes/test_affine_quantized_tensor_parallel.py b/test/dtypes/test_affine_quantized_tensor_parallel.py
index 3abb736f92..76b6b74a3d 100644
--- a/test/dtypes/test_affine_quantized_tensor_parallel.py
+++ b/test/dtypes/test_affine_quantized_tensor_parallel.py
@@ -13,6 +13,7 @@
     float8_dynamic_activation_float8_weight,
     float8_weight_only,
     int4_weight_only,
+    int8_dynamic_activation_int8_weight,
     int8_weight_only,
 )
 from torchao.quantization.observer import PerRow, PerTensor
@@ -166,9 +167,21 @@ def test_tp_gemlite(self, dtype):
         return self._test_tp(dtype)
 
 
+class TestInt8dqAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel):
+    QUANT_METHOD_FN = staticmethod(int8_dynamic_activation_int8_weight)
+    COMMON_DTYPES = [torch.bfloat16]
+
+    @common_utils.parametrize("dtype", COMMON_DTYPES)
+    @with_comms
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_tp(self, dtype):
+        return self._test_tp(dtype)
+
+
 common_utils.instantiate_parametrized_tests(TestInt8woAffineQuantizedTensorParallel)
 common_utils.instantiate_parametrized_tests(TestInt4woAffineQuantizedTensorParallel)
 common_utils.instantiate_parametrized_tests(TestGemliteLayoutTensorParallel)
+common_utils.instantiate_parametrized_tests(TestInt8dqAffineQuantizedTensorParallel)
 
 # Run only on H100
 if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0):