diff --git a/test/dtypes/test_affine_quantized_tensor_parallel.py b/test/dtypes/test_affine_quantized_tensor_parallel.py
index 8c1301226f..b237860600 100644
--- a/test/dtypes/test_affine_quantized_tensor_parallel.py
+++ b/test/dtypes/test_affine_quantized_tensor_parallel.py
@@ -1,12 +1,17 @@
+import torch
 from torchao.testing.utils import copy_tests, TorchAOTensorParallelTestCase
 from torch.testing._internal.common_utils import run_tests
-from torchao.quantization import int8_weight_only
+from torchao.quantization import int8_weight_only, float8_weight_only
 
-class TestAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase):
-    pass
+class TestInt8woAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase):
+    QUANT_METHOD_FN = staticmethod(int8_weight_only)
+copy_tests(TorchAOTensorParallelTestCase, TestInt8woAffineQuantizedTensorParallel, "int8wo_tp")
 
-
-copy_tests(TorchAOTensorParallelTestCase, TestAffineQuantizedTensorParallel, "aqt_tp")
+# Run only on H100
+if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0):
+    class TestFloat8woAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase):
+        QUANT_METHOD_FN = staticmethod(float8_weight_only)
+    copy_tests(TorchAOTensorParallelTestCase, TestFloat8woAffineQuantizedTensorParallel, "fp8wo_tp")
 
 if __name__ == "__main__":
     run_tests()
diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py
index 3dc632cd0e..c2c8e3c0b0 100644
--- a/torchao/dtypes/affine_quantized_tensor.py
+++ b/torchao/dtypes/affine_quantized_tensor.py
@@ -1094,20 +1094,31 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
             return return_and_correct_aliasing(
                 func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
             )
-        if func is aten.clone.default:
+        elif func is aten.clone.default:
             return return_and_correct_aliasing(
                 func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
            )
-        if func is aten.t.default:
+        elif func is aten.t.default:
             """we don't need to repack the weight and just rely on external shape
             being changed and record the status of transpose/no-transpose
             """
             args[0].transposed = not args[0].transposed
             return return_and_correct_aliasing(func, args, kwargs, args[0])
-
-        raise NotImplementedError(
-            f"Float8AQTLayout dispatch: attempting to run {func}, this is not supported"
-        )
+        elif func is aten.slice.Tensor:
+            self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
+            if dim == 0:
+                return return_and_correct_aliasing(
+                    func, args, kwargs, args[0]._apply_fn_to_data(lambda x: aten.slice.Tensor(x, dim, start, end, step))
+                )
+            elif dim == 1:
+                assert len(self.scale.shape) == 1, f"slice dim==1 only works when len(scale.shape) == 1 currently, got: {self.scale.shape}"
+                return Float8AQTLayout(aten.slice.Tensor(self.float8_data, dim, start, end, step), self.scale, None, self.layout_type)
+            else:
+                raise NotImplementedError(f"Float8AQTLayout dispatch: attempting to run {func}, with dim={dim}, that is not supported")
+        else:
+            raise NotImplementedError(
+                f"Float8AQTLayout dispatch: attempting to run {func}, this is not supported"
+            )
 
     __torch_function__ = torch._C._disabled_torch_function_impl
 
@@ -1644,6 +1655,28 @@ def _linear_fp8_act_fp8_weight_impl(
         use_fast_accum=scaled_mm_config.use_fast_accum,
     ).reshape(out_shape)
 
+def _linear_fp_act_fp8_weight_check(
+    input_tensor: Union[torch.Tensor, AffineQuantizedTensor],
+    weight_tensor: Union[torch.Tensor, AffineQuantizedTensor],
+    bias: Optional[torch.Tensor],
+) -> bool:
+    return (
+        # input is native float tensor
+        not is_traceable_wrapper_subclass(input_tensor) and
+        input_tensor.is_floating_point() and
+        # weight is float8 quantized affine quantized tensor
+        isinstance(weight_tensor, AffineQuantizedTensor) and
+        isinstance(weight_tensor.layout_type, Float8LayoutType)
+        and weight_tensor.layout_tensor.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]
+        and (weight_tensor.shape == weight_tensor.block_size or _is_rowwise_scaled(weight_tensor))
+    )
+
+def _linear_fp_act_fp8_weight_impl(
+    input_tensor: torch.Tensor,
+    weight_tensor: AffineQuantizedTensor,
+    bias: Optional[torch.Tensor],
+):
+    return torch.nn.functional.linear(input_tensor, weight_tensor.dequantize(), bias)
 
 def _linear_fp_act_int4_weight_sparse_marlin_check(input_tensor, weight_tensor, bias):
     return (
@@ -1694,6 +1727,7 @@ def _register_aqt_quantized_linear_dispatches():
         (_linear_int8_act_int8_weight_semi_structured_sparse_check, _linear_int8_act_int8_weight_semi_structured_sparse_impl),
         (_linear_int8_act_int8_weight_block_sparse_check, _linear_int8_act_int8_weight_block_sparse_impl),
         (_linear_fp8_act_fp8_weight_check, _linear_fp8_act_fp8_weight_impl),
+        (_linear_fp_act_fp8_weight_check, _linear_fp_act_fp8_weight_impl),
         (_linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl),
         (_linear_fp_act_int8_weight_check, _linear_fp_act_int8_weight_impl),
         (_linear_f16_act_floatx_weight_check, _linear_f16_act_floatx_weight_impl),
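
For context, a minimal usage sketch of the new fallback path added above (not part of the diff; the layer sizes and variable names are illustrative): quantizing a linear weight with float8_weight_only and running a plain high-precision activation through it should now be routed to _linear_fp_act_fp8_weight_impl, which dequantizes the weight and calls F.linear instead of raising.

# Hypothetical usage sketch, assuming torchao's quantize_ API; sizes are made up.
import torch
from torchao.quantization import quantize_, float8_weight_only

# Weight becomes a float8 AffineQuantizedTensor after quantize_.
linear = torch.nn.Linear(128, 256, bias=False, device="cuda", dtype=torch.bfloat16)
quantize_(linear, float8_weight_only())

# bf16 activation x fp8 weight: handled by the new dequantize + F.linear fallback.
x = torch.randn(16, 128, device="cuda", dtype=torch.bfloat16)
y = linear(x)
print(y.shape)  # torch.Size([16, 256])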