From 8e80b3932b00dbf3edcf2ad1211488181a6023a2 Mon Sep 17 00:00:00 2001
From: jainapurva
Date: Tue, 10 Sep 2024 15:19:37 -0700
Subject: [PATCH 1/6] Float8 autoquant weight only

---
 torchao/quantization/autoquant.py | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py
index 39482caf84..9e49a689ed 100644
--- a/torchao/quantization/autoquant.py
+++ b/torchao/quantization/autoquant.py
@@ -9,7 +9,7 @@
     Int8WeightOnlyQuantizedLinearWeight,
     QuantizedLinearWeightBase,
 )
-from torchao.dtypes import AffineQuantizedTensor, PlainLayoutType, TensorCoreTiledLayoutType
+from torchao.dtypes import AffineQuantizedTensor, PlainLayoutType, TensorCoreTiledLayoutType, Float8LayoutType
 from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
 from torch.utils._python_dispatch import return_and_correct_aliasing
 from .quant_primitives import (
@@ -477,6 +477,15 @@ def _quantized_linear_op(act_mat, w_qtensor, bias):
     def from_float(cls, weight):
         return weight
 
+class AQFloat8WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin):
+    """
+    AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight
+    """
+    @classmethod
+    def from_float(cls, weight):
+        block_size = (1, weight.shape[1])
+        return super(AQInt8WeightOnlyQuantizedLinearWeight, cls).from_hp_to_floatx(weight, block_size, target_dtype=torch.float8_e4m3fn, layout_type=Float8LayoutType())
+
 # here we don't include int4 quantization in since int8 tends to be a better apples to apples comparison
 DEFAULT_AUTOQUANT_CLASS_LIST = [
     AQFloatLinearWeight,

From b8ab4ee474695991f4e31e722368a9b48b8a2e60 Mon Sep 17 00:00:00 2001
From: jainapurva
Date: Tue, 10 Sep 2024 15:19:37 -0700
Subject: [PATCH 2/6] Float8 autoquant weight only

---
 scripts/hf_eval.py                        | 2 +-
 scripts/prepare.sh                        | 0
 torchao/dtypes/affine_quantized_tensor.py | 2 +-
 torchao/kernel/intmm.py                   | 5 ++++-
 torchao/quantization/autoquant.py         | 7 ++++---
 5 files changed, 10 insertions(+), 6 deletions(-)
 mode change 100644 => 100755 scripts/hf_eval.py
 mode change 100644 => 100755 scripts/prepare.sh

diff --git a/scripts/hf_eval.py b/scripts/hf_eval.py
old mode 100644
new mode 100755
index 5f008ee439..b4171102d2
--- a/scripts/hf_eval.py
+++ b/scripts/hf_eval.py
@@ -89,7 +89,7 @@ def all_linear(mod, name):
     with torch.no_grad():
         result = evaluate(
             HFLM(
-                pretrained=model.to(device),
+                pretrained=model,
                 tokenizer=tokenizer,
                 batch_size=batch_size,
                 max_length=max_length),
diff --git a/scripts/prepare.sh b/scripts/prepare.sh
old mode 100644
new mode 100755
diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py
index 418e75d039..025f36ec39 100644
--- a/torchao/dtypes/affine_quantized_tensor.py
+++ b/torchao/dtypes/affine_quantized_tensor.py
@@ -335,8 +335,8 @@ def from_hp_to_floatx(
         input_float: torch.Tensor,
         block_size: Tuple[int, ...],
         target_dtype: torch.dtype,
-        scale_dtype: Optional[torch.dtype],
         layout_type: LayoutType,
+        scale_dtype: Optional[torch.dtype] = None,
     ):
 
         if target_dtype in FP8_TYPES:
diff --git a/torchao/kernel/intmm.py b/torchao/kernel/intmm.py
index 3005cb16a9..f13fb5bf55 100644
--- a/torchao/kernel/intmm.py
+++ b/torchao/kernel/intmm.py
@@ -69,7 +69,10 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
         input = (
             input.contiguous()
         )  # (it seems the transpose makes cublas check the above j constraint on i)
-        return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)
+        try:
+            return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)
+        except:
+            return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to(torch.int32)
 else:
     def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
         """
diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py
index 9e49a689ed..cde651fa07 100644
--- a/torchao/quantization/autoquant.py
+++ b/torchao/quantization/autoquant.py
@@ -479,12 +479,12 @@ def from_float(cls, weight):
 
 class AQFloat8WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin):
     """
-    AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight
+    AutoQuantizable version of Float8WeightOnlyQuantizedLinearWeight
     """
     @classmethod
     def from_float(cls, weight):
         block_size = (1, weight.shape[1])
-        return super(AQInt8WeightOnlyQuantizedLinearWeight, cls).from_hp_to_floatx(weight, block_size, target_dtype=torch.float8_e4m3fn, layout_type=Float8LayoutType())
+        return super(AQFloat8WeightOnlyQuantizedLinearWeight, cls).from_hp_to_floatx(weight, block_size, target_dtype=torch.float8_e4m3fn, layout_type=Float8LayoutType())
 
 # here we don't include int4 quantization in since int8 tends to be a better apples to apples comparison
 DEFAULT_AUTOQUANT_CLASS_LIST = [
@@ -494,12 +494,13 @@ def from_float(cls, weight):
     # AQInt8WeightOnlyQuantizedLinearWeight3, # TODO this gets picked in places where it makes perf worse, why?
     AQInt8DynamicallyQuantizedLinearWeight,
+    AQFloat8WeightOnlyQuantizedLinearWeight,
 ]
 
 DEFAULT_INT4_AUTOQUANT_CLASS_LIST = [
     AQFloatLinearWeight,
     AQInt8DynamicallyQuantizedLinearWeight,
-    AQInt4G64WeightOnlyQuantizedLinearWeight
+    AQInt4G64WeightOnlyQuantizedLinearWeight,
 ]
 
 def _change_linears_to_autoquantizable(model, **kwargs):
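For orientation, the class added above quantizes only the weight: block_size = (1, weight.shape[1]) means one scale per output row, and target_dtype=torch.float8_e4m3fn stores the weight in 8-bit floating point. A minimal standalone sketch of that per-row scheme (an illustration only, not the torchao implementation; the helper names are made up):

import torch

def quantize_fp8_rowwise(weight: torch.Tensor):
    # one scale per output row, so each row maps onto the full e4m3 range (max 448)
    amax = weight.abs().amax(dim=1, keepdim=True).clamp(min=1e-12)
    scale = amax / torch.finfo(torch.float8_e4m3fn).max
    return (weight / scale).to(torch.float8_e4m3fn), scale

def dequantize_fp8(w_fp8: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # cast back up and undo the per-row scaling
    return w_fp8.to(torch.float32) * scale

w = torch.randn(16, 64)
w_fp8, scale = quantize_fp8_rowwise(w)
print((w - dequantize_fp8(w_fp8, scale)).abs().max())  # small reconstruction error

In the patches themselves this bookkeeping is delegated to AffineQuantizedTensor.from_hp_to_floatx with a Float8LayoutType, so scales and packing reuse the existing affine-quantized machinery.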
From 0ba6a2c5b4c4f132e9107a4e63e794bce32b4461 Mon Sep 17 00:00:00 2001
From: jainapurva
Date: Tue, 10 Sep 2024 15:19:37 -0700
Subject: [PATCH 3/6] Float8 autoquant weight only

---
 scripts/hf_eval.py                |  2 +-
 scripts/prepare.sh                |  0
 torchao/quantization/autoquant.py | 13 ++++++++++---
 3 files changed, 11 insertions(+), 4 deletions(-)
 mode change 100755 => 100644 scripts/hf_eval.py
 mode change 100755 => 100644 scripts/prepare.sh

diff --git a/scripts/hf_eval.py b/scripts/hf_eval.py
old mode 100755
new mode 100644
index b4171102d2..5f008ee439
--- a/scripts/hf_eval.py
+++ b/scripts/hf_eval.py
@@ -89,7 +89,7 @@ def all_linear(mod, name):
     with torch.no_grad():
         result = evaluate(
             HFLM(
-                pretrained=model,
+                pretrained=model.to(device),
                 tokenizer=tokenizer,
                 batch_size=batch_size,
                 max_length=max_length),
diff --git a/scripts/prepare.sh b/scripts/prepare.sh
old mode 100755
new mode 100644
diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py
index cde651fa07..fa4ca36d85 100644
--- a/torchao/quantization/autoquant.py
+++ b/torchao/quantization/autoquant.py
@@ -479,12 +479,19 @@ def from_float(cls, weight):
 
 class AQFloat8WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin):
     """
-    AutoQuantizable version of Float8WeightOnlyQuantizedLinearWeight
+    AutoQuantizable version of Float8WeightOnlyQuantizedLinearWeight for target_dtype=torch.float8_e4m3fn
     """
+    target_dtype: torch.dtype = torch.float8_e4m3fn
+
+    @staticmethod
+    def _quantized_linear_op(act_mat, w_qtensor, bias):
+        return torch.nn.functional.linear(act_mat, w_qtensor.dequantize(), bias)
+
     @classmethod
     def from_float(cls, weight):
         block_size = (1, weight.shape[1])
-        return super(AQFloat8WeightOnlyQuantizedLinearWeight, cls).from_hp_to_floatx(weight, block_size, target_dtype=torch.float8_e4m3fn, layout_type=Float8LayoutType())
+        return super(AQFloat8WeightOnlyQuantizedLinearWeight, cls).from_hp_to_floatx(weight, block_size, target_dtype=cls.target_dtype, layout_type=Float8LayoutType())
+
 
 # here we don't include int4 quantization in since int8 tends to be a better apples to apples comparison
 DEFAULT_AUTOQUANT_CLASS_LIST = [
@@ -500,7 +507,7 @@ def from_float(cls, weight):
 DEFAULT_INT4_AUTOQUANT_CLASS_LIST = [
     AQFloatLinearWeight,
     AQInt8DynamicallyQuantizedLinearWeight,
-    AQInt4G64WeightOnlyQuantizedLinearWeight,
+    AQInt4G64WeightOnlyQuantizedLinearWeight
 ]
 
 def _change_linears_to_autoquantizable(model, **kwargs):
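The refactor above pins the target dtype as a class attribute, reads it back through cls.target_dtype in from_float, and overrides _quantized_linear_op to dequantize the weight and fall back to a plain torch.nn.functional.linear, the usual weight-only pattern. One thing the class-attribute indirection makes straightforward (a hypothetical sketch, not part of this series) is a variant that targets a different float8 format by overriding only the attribute:

import torch
from torchao.quantization.autoquant import AQFloat8WeightOnlyQuantizedLinearWeight

class AQFloat8E5M2WeightOnlyQuantizedLinearWeight(AQFloat8WeightOnlyQuantizedLinearWeight):
    # hypothetical e5m2 variant; from_float picks this up via cls.target_dtype
    target_dtype: torch.dtype = torch.float8_e5m2

Since from_float already forwards target_dtype=cls.target_dtype to from_hp_to_floatx, no other code changes would be needed for such a subclass.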
From 988af92d12648cb9b5d82a5b0692bbea2ce69096 Mon Sep 17 00:00:00 2001
From: jainapurva
Date: Fri, 20 Sep 2024 15:01:46 -0700
Subject: [PATCH 4/6] Test cases

---
 test/integration/test_integration.py | 11 ++++++++++-
 torchao/quantization/autoquant.py    |  8 +++++++-
 2 files changed, 17 insertions(+), 2 deletions(-)

diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py
index 6a5ea8ef9d..8e047985c5 100644
--- a/test/integration/test_integration.py
+++ b/test/integration/test_integration.py
@@ -72,7 +72,7 @@
     AQInt8WeightOnlyQuantizedLinearWeight2,
     AQInt8WeightOnlyQuantizedLinearWeight3,
     AutoQuantizableLinearWeight,
-
+    AQFloat8WeightOnlyQuantizedLinearWeight,
 )
 from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
 import os
@@ -98,6 +98,7 @@
 COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]
 
 COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES)).copy()
+is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
 
 def _int8wo_api(mod):
     if TORCH_VERSION_AT_LEAST_2_4:
@@ -744,6 +745,14 @@ def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype):
             AQInt8WeightOnlyQuantizedLinearWeight3.from_float, device, 35, test_dtype=dtype
         )
 
+    @parameterized.expand(COMMON_DEVICE_DTYPE)
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch")
+    @unittest.skipIf(not is_H100, "Need H100 to run")
+    def test_aq_float8_weight_only_quant_subclass(self, device, dtype):
+        self._test_lin_weight_subclass_impl(
+            AQFloat8WeightOnlyQuantizedLinearWeight.from_float, device, 30, test_dtype=dtype
+        )
+
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
     # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py
index fa4ca36d85..089add1d87 100644
--- a/torchao/quantization/autoquant.py
+++ b/torchao/quantization/autoquant.py
@@ -501,7 +501,6 @@ def from_float(cls, weight):
     # AQInt8WeightOnlyQuantizedLinearWeight3, # TODO this gets picked in places where it makes perf worse, why?
     AQInt8DynamicallyQuantizedLinearWeight,
-    AQFloat8WeightOnlyQuantizedLinearWeight,
 ]
 
 DEFAULT_INT4_AUTOQUANT_CLASS_LIST = [
@@ -510,6 +509,11 @@ def from_float(cls, weight):
     AQInt4G64WeightOnlyQuantizedLinearWeight
 ]
 
+OTHER_AUTOQUANT_CLASS_LIST = [
+    AQFloat8WeightOnlyQuantizedLinearWeight,
+]
+
+
 def _change_linears_to_autoquantizable(model, **kwargs):
     """
     Converts all linear weight tensors to the
@@ -634,6 +638,8 @@ def autoquant(
     if set_inductor_config:
         torchao.quantization.utils.recommended_inductor_config_setter()
 
+    if qtensor_class_list in OTHER_AUTOQUANT_CLASS_LIST:
+        assert torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9), "float8 requires CUDA arch >= 8.9"
 
     # perform initial swap from linear weights
     # to AutoQuantizableLinearWeight

From acb2afc4d8d24e4641ec4ee65c1b1568acfbbe94 Mon Sep 17 00:00:00 2001
From: jainapurva
Date: Tue, 24 Sep 2024 15:08:49 -0700
Subject: [PATCH 5/6] Review fixes

---
 test/kernel/test_autotuner.py | 20 ++++++++++++++++++++
 torchao/kernel/intmm.py       |  4 +++-
 2 files changed, 23 insertions(+), 1 deletion(-)

diff --git a/test/kernel/test_autotuner.py b/test/kernel/test_autotuner.py
index 82fb117363..bdf24d81c4 100644
--- a/test/kernel/test_autotuner.py
+++ b/test/kernel/test_autotuner.py
@@ -16,6 +16,7 @@
 
 logging.basicConfig(level=logging.INFO)
 
+is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
 
 class TestQuantFlow(unittest.TestCase):
 
@@ -49,6 +50,25 @@ def test_int_mm(self, device, dtype):
         assert out32_2.dtype == out32_1.dtype
         torch.testing.assert_allclose(out32_1, out32_2)
 
+    @parameterized.expand(
+        [
+            ("cuda", torch.bfloat16),
+            ("cuda", torch.float16),
+        ]
+    )
+    @unittest.skipIf(not is_H100, "Need H100")
+    def test_int_mm_float8(self, device, dtype):
+        from torchao.kernel import intmm
+
+        dtype = torch.bfloat16
+        m, k, n = (128, 64, 16)
+        x = torch.randn(m, k, dtype=dtype, device=device)
+        w = torch.randn(n, k, dtype=dtype, device=device).t()
+        x_float8 = x.to(dtype=torch.float8_e4m3fn)
+        w_float8 = w.to(dtype=torch.float8_e4m3fn)
+        out32_1 = intmm.safe_int_mm(x_float8, w_float8)
+        assert out32_1.dtype == torch.int32
+
     @parameterized.expand(
         [
             ("cuda", torch.bfloat16),
diff --git a/torchao/kernel/intmm.py b/torchao/kernel/intmm.py
index f13fb5bf55..b9620eb42e 100644
--- a/torchao/kernel/intmm.py
+++ b/torchao/kernel/intmm.py
@@ -71,7 +71,9 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
         )  # (it seems the transpose makes cublas check the above j constraint on i)
         try:
             return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)
-        except:
+        except Exception as e:
+            # fallback path, would run on H100 for float8 dtypes
+            # Exception on H100 float8 dtype : "addmm_cuda" not implemented for 'Float8_e4m3fn'
             return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to(torch.int32)
 else:
     def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
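Because the float8 candidate now lives in OTHER_AUTOQUANT_CLASS_LIST rather than in the default list, it has to be requested explicitly through autoquant's qtensor_class_list argument, and the new assertion limits it to CUDA devices with compute capability 8.9 or newer (Ada or Hopper class). A hedged usage sketch based only on the signatures visible in this series; the model and shapes are placeholders:

import torch
import torchao
from torchao.quantization.autoquant import OTHER_AUTOQUANT_CLASS_LIST

if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9):
    model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to("cuda", torch.bfloat16)
    # ask autoquant to consider the float8 weight-only candidate
    model = torchao.autoquant(model, qtensor_class_list=OTHER_AUTOQUANT_CLASS_LIST)
    model(torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16))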
From ce8ad0623d3d3179de61ff062b56e1ec4b0b7924 Mon Sep 17 00:00:00 2001
From: jainapurva
Date: Tue, 24 Sep 2024 15:08:49 -0700
Subject: [PATCH 6/6] Review fixes

---
 test/kernel/test_autotuner.py | 2 +-
 torchao/kernel/intmm.py       | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/test/kernel/test_autotuner.py b/test/kernel/test_autotuner.py
index bdf24d81c4..4ed0974172 100644
--- a/test/kernel/test_autotuner.py
+++ b/test/kernel/test_autotuner.py
@@ -56,7 +56,7 @@ def test_int_mm(self, device, dtype):
             ("cuda", torch.float16),
         ]
     )
-    @unittest.skipIf(not is_H100, "Need H100")
+    @unittest.skipIf(not is_H100, "Needs H100")
     def test_int_mm_float8(self, device, dtype):
         from torchao.kernel import intmm
 
diff --git a/torchao/kernel/intmm.py b/torchao/kernel/intmm.py
index b9620eb42e..81e7b19b15 100644
--- a/torchao/kernel/intmm.py
+++ b/torchao/kernel/intmm.py
@@ -71,7 +71,7 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
         )  # (it seems the transpose makes cublas check the above j constraint on i)
         try:
             return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)
-        except Exception as e:
+        except Exception:
             # fallback path, would run on H100 for float8 dtypes
             # Exception on H100 float8 dtype : "addmm_cuda" not implemented for 'Float8_e4m3fn'
             return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to(torch.int32)
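The net effect on the kernel side is that safe_int_mm now degrades gracefully when the fused mm op rejects a dtype, which is what the new test exercises with float8 inputs: the fallback computes in float32 and casts the result to int32. A standalone version of that check, assuming a CUDA device with float8 support (compute capability 8.9+) and reusing the shapes from the test above:

import torch
from torchao.kernel import intmm

if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9):
    m, k, n = 128, 64, 16
    x = torch.randn(m, k, dtype=torch.bfloat16, device="cuda").to(torch.float8_e4m3fn)
    w = torch.randn(n, k, dtype=torch.bfloat16, device="cuda").t().to(torch.float8_e4m3fn)
    out = intmm.safe_int_mm(x, w)
    assert out.dtype == torch.int32  # produced by the float32 fallback path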