diff --git a/test/dtypes/test_affine_quantized_tensor_parallel.py b/test/dtypes/test_affine_quantized_tensor_parallel.py index b237860600..8e6855a5df 100644 --- a/test/dtypes/test_affine_quantized_tensor_parallel.py +++ b/test/dtypes/test_affine_quantized_tensor_parallel.py @@ -1,7 +1,20 @@ import torch +import unittest from torchao.testing.utils import copy_tests, TorchAOTensorParallelTestCase from torch.testing._internal.common_utils import run_tests -from torchao.quantization import int8_weight_only, float8_weight_only +from torch.testing._internal import common_utils +from torchao.quantization import int8_weight_only, float8_weight_only, float8_dynamic_activation_float8_weight +from torchao.quantization.observer import PerRow, PerTensor +import torch.distributed as dist +from torch.distributed._tensor import DTensor, Replicate, Shard, DeviceMesh +from torch.testing._internal.distributed._tensor.common_dtensor import ( + DTensorTestBase, + with_comms, + NUM_DEVICES, +) +from torchao.quantization.quant_api import quantize_ +from torchao.dtypes import AffineQuantizedTensor +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 class TestInt8woAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase): QUANT_METHOD_FN = staticmethod(int8_weight_only) @@ -13,5 +26,133 @@ class TestFloat8woAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase): QUANT_METHOD_FN = staticmethod(float8_weight_only) copy_tests(TorchAOTensorParallelTestCase, TestFloat8woAffineQuantizedTensorParallel, "fp8wo_tp") +# Run only on H100 +if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0): + class TestFloat8dqAffineQuantizedTensorParallel(DTensorTestBase): + """Basic test case for tensor subclasses + """ + COMMON_DTYPES = [torch.bfloat16, torch.float16, torch.float32] + TENSOR_SUBCLASS = AffineQuantizedTensor + QUANT_METHOD_FN = staticmethod(float8_dynamic_activation_float8_weight) + QUANT_METHOD_KWARGS = {} + + @staticmethod + def colwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module: + """ + Shard linear layer of the model in column-wise fashion + """ + # Column-wise is wrt to A^T, so for A it is row-wise. + # Number of rows per rank + orig_weight = m.linear.weight + n_local_rows = orig_weight.size(0) // mesh.size() + rank = mesh.get_local_rank() + local_shard = orig_weight[rank * n_local_rows : (rank + 1) * n_local_rows, :] + # Construct DTensor from local shard + dtensor = DTensor.from_local(local_shard, mesh, [Shard(0)]) + # Replace parameter in module + m.linear.weight = torch.nn.Parameter( + dtensor, requires_grad=False + ) + return m + + @staticmethod + def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module: + """ + Shard linear layer of the model in row-wise fashion + """ + # Row-wise is wrt to A^T, so for A it is column-wise. + # Number of rows per rank + orig_weight = m.linear.weight + n_local_cols = orig_weight.size(1) // mesh.size() + rank = mesh.get_local_rank() + local_shard = orig_weight[:, rank * n_local_cols : (rank + 1) * n_local_cols] + # Construct DTensor from local shard + dtensor = DTensor.from_local(local_shard, mesh, [Shard(1)], run_check=True) + # Replace parameter in module + m.linear.weight = torch.nn.Parameter( + dtensor, requires_grad=False + ) + return m + + def quantize(self, m: torch.nn.Module) -> torch.nn.Module: + """ + Quantize the model + """ + quantize_(m, self.QUANT_METHOD_FN(**self.QUANT_METHOD_KWARGS)) + return m + + def _test_tp(self, dtype): + device = "cuda" + # To make sure different ranks create the same module + torch.manual_seed(5) + + class M(torch.nn.Module): + def __init__(self, in_features, out_features, **kwargs) -> None: + super().__init__(**kwargs) + self.linear = torch.nn.Linear(in_features, out_features, bias=False, device="cuda") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + + # Get rank and device + device = torch.device(f"cuda:{self.rank % torch.cuda.device_count()}") + + # Original model + proj_up = M(1024, 2048).to(device).to(dtype) + proj_dn = M(2048, 1024).to(device).to(dtype) + example_input = 100 * torch.randn(128, 1024, device=device, dtype=dtype) + y = proj_dn(proj_up(example_input)) + # Quantize the model + up_quant = self.quantize(proj_up) + dn_quant = self.quantize(proj_dn) + y_q = dn_quant(up_quant(example_input)) + + mesh = self.build_device_mesh() + mesh.device_type = "cuda" + + # Shard the models + up_dist = self.colwise_shard(up_quant, mesh) + dn_dist = self.rowwise_shard(dn_quant, mesh) + + # We need to turn inputs into DTensor form as well -- just a format change + input_dtensor = DTensor.from_local( + example_input, mesh, [Replicate()] + ) + + y_d = dn_dist(up_dist(input_dtensor)) + + if not TORCH_VERSION_AT_LEAST_2_5: + # Need torch 2.5 to support compiled tensor parallelism + return + + up_compiled = torch.compile(up_dist) + y_up = up_compiled(input_dtensor) + dn_compiled = torch.compile(dn_dist) + y_dn = dn_compiled(y_up) + + class TestFloat8dqTensorAffineQuantizedTensorParallel(TestFloat8dqAffineQuantizedTensorParallel): + QUANT_METHOD_FN = staticmethod(float8_dynamic_activation_float8_weight) + QUANT_METHOD_KWARGS = {"granularity": PerTensor()} + COMMON_DTYPES = [torch.bfloat16, torch.float16, torch.float32] + + @common_utils.parametrize("dtype", COMMON_DTYPES) + @with_comms + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + def test_tp(self, dtype): + return self._test_tp(dtype) + + class TestFloat8dqRowAffineQuantizedTensorParallel(TestFloat8dqAffineQuantizedTensorParallel): + QUANT_METHOD_FN = staticmethod(float8_dynamic_activation_float8_weight) + QUANT_METHOD_KWARGS = {"granularity": PerRow()} + COMMON_DTYPES = [torch.bfloat16] + + @common_utils.parametrize("dtype", COMMON_DTYPES) + @with_comms + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + def test_tp(self, dtype): + return self._test_tp(dtype) + + common_utils.instantiate_parametrized_tests(TestFloat8dqTensorAffineQuantizedTensorParallel) + common_utils.instantiate_parametrized_tests(TestFloat8dqRowAffineQuantizedTensorParallel) if __name__ == "__main__": run_tests() diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index ff5deb7b08..d14a5dd17c 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -1062,9 +1062,12 @@ def __init__( def _apply_fn_to_data(self, fn): """ Applys a fn to all tensor components stored on this class""" - fn(self.float8_data) - fn(self.scale) - return self + return self.__class__( + fn(self.float8_data), + fn(self.scale), + self.transposed, + self._layout, + ) def to(self, *args, **kwargs): kwargs = self._get_to_kwargs(*args, **kwargs) @@ -1107,12 +1110,21 @@ def __torch_dispatch__(cls, func, types, args, kwargs): elif func is aten.slice.Tensor: self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) if dim == 0: + #TODO: scale replecation should be dependent on block size + if self.scale.ndim == 1: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(lambda x: aten.slice.Tensor(x, dim, start, end, step)) + ) + elif self.scale.ndim == 0: + return return_and_correct_aliasing( + func, args, kwargs, Float8AQTTensorImpl(aten.slice.Tensor(self.float8_data, dim, start, end, step), self.scale, None, self._layout) + ) + else: + raise NotImplementedError(f"Float8AQTTensorImpl dispatch: attempting to run {func}, with scale ndim={dim}, that is not supported") + elif dim == 1: return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(lambda x: aten.slice.Tensor(x, dim, start, end, step)) + func, args, kwargs, Float8AQTTensorImpl(aten.slice.Tensor(self.float8_data, dim, start, end, step).contiguous(), self.scale, None, self._layout) ) - elif dim == 1: - assert len(self.scale.shape) == 1, f"slice dim==1 only works when len(scale.shape) == 1 currently, got: {self.scale.shape}" - return Float8AQTTensorImpl(aten.slice.Tensor(self.float8_data, dim, start, end, step), self.scale, None, self._layout) else: raise NotImplementedError(f"Float8AQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported") else: diff --git a/torchao/quantization/linear_activation_quantized_tensor.py b/torchao/quantization/linear_activation_quantized_tensor.py index f2eae07152..4c0f41b497 100644 --- a/torchao/quantization/linear_activation_quantized_tensor.py +++ b/torchao/quantization/linear_activation_quantized_tensor.py @@ -165,6 +165,20 @@ def _(func, types, args, kwargs): func, args, kwargs, args[0]._apply_fn_to_data(torch.t) ) +@implements(aten.slice.Tensor) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, LinearActivationQuantizedTensor( + func(args[0].original_weight_tensor, *args[1:]), args[0].input_quant_func) + ) + +# this is needed for DTensor.from_local() and for flattening tensor +@implements(aten.view.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, LinearActivationQuantizedTensor(func(args[0].original_weight_tensor, *args[1:]), args[0].input_quant_func) + ) + to_linear_activation_quantized = LinearActivationQuantizedTensor.from_float if TORCH_VERSION_AT_LEAST_2_5: