From 3d42329a1535f2c071acfb3832915bf5411467f6 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 9 Aug 2024 16:05:59 +0800 Subject: [PATCH 01/41] initial commit --- .../prototype/quantized_training/README.md | 7 + .../prototype/quantized_training/__init__.py | 1 + .../prototype/quantized_training/subclass.py | 174 ++++++++++++++++++ 3 files changed, 182 insertions(+) create mode 100644 torchao/prototype/quantized_training/README.md create mode 100644 torchao/prototype/quantized_training/__init__.py create mode 100644 torchao/prototype/quantized_training/subclass.py diff --git a/torchao/prototype/quantized_training/README.md b/torchao/prototype/quantized_training/README.md new file mode 100644 index 0000000000..43f832ffa6 --- /dev/null +++ b/torchao/prototype/quantized_training/README.md @@ -0,0 +1,7 @@ +# Quantized training + +This folder contains experimental work on quantized training (QT). The main difference from quantization-aware training (QAT) is that in QT, we don't keep a high-precision copy of model weights. We take inspirations from: +- Q-GaLore: [[paper](https://arxiv.org/abs/2407.08296)] [[code](https://github.com/VITA-Group/Q-GaLore)] +- AQT: [[related paper](https://arxiv.org/abs/2105.03536)] [[code](https://github.com/google/aqt)] + +Currently we only support weight-only channel-wise INT8 symmetric quantization. diff --git a/torchao/prototype/quantized_training/__init__.py b/torchao/prototype/quantized_training/__init__.py new file mode 100644 index 0000000000..ad849e0fff --- /dev/null +++ b/torchao/prototype/quantized_training/__init__.py @@ -0,0 +1 @@ +from .subclass import Int8QTLinearWeight, int8_weight_only_quantized_training diff --git a/torchao/prototype/quantized_training/subclass.py b/torchao/prototype/quantized_training/subclass.py new file mode 100644 index 0000000000..7425af02a2 --- /dev/null +++ b/torchao/prototype/quantized_training/subclass.py @@ -0,0 +1,174 @@ +import torch +from torch import Tensor +from torch.utils._python_dispatch import return_and_correct_aliasing + +from torchao.dtypes.utils import _implements, _dispatch__torch_function__, _dispatch__torch_dispatch__ +from torchao.quantization.quant_api import _get_linear_subclass_inserter + +aten = torch.ops.aten + + +# the main difference of this tensor subclass from AffineQuantizedTensor: +# 1. F.linear is differentiable i.e. backward is defined. +# 2. support stochastic rounding when casting from floating point. +class Int8QTLinearWeight(Tensor): + implements = classmethod(_implements) + __torch_function__ = classmethod(_dispatch__torch_function__) + __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__) + + def __new__(cls, int_data, scale, requires_grad=False): + return Tensor._make_wrapper_subclass( + cls, + int_data.shape, + dtype=scale.dtype, + device=int_data.device, + requires_grad=requires_grad, + ) + + def __init__(self, int_data, scale, requires_grad=False): + """Create a symmetric quantized INT8 weight. This tensor will appear to have the same dtype + as `scale.dtype`. All in-place update ops will perform stochastic rounding. + """ + # NOTE: should scale always be FP32? 
+ assert int_data.dtype is torch.int8 + self.int_data = int_data + self.scale = scale + + def __tensor_flatten__(self): + return ["int_data", "scale"], [] + + @classmethod + def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None): + return cls(tensor_data_dict["int_data"], tensor_data_dict["scale"], *tensor_attributes) + + @staticmethod + def quantize(tensor: Tensor, stochastic_rounding: bool = False): + original_dtype = tensor.dtype + tensor = tensor.float() + + # absmax symmetric quantization + scale = tensor.abs().amax(-1) / 127 + tensor = tensor / scale.clip(1e-12).view(-1, 1) + + if stochastic_rounding: + # floor is required since .to(torch.int8) will convert 3.1 to 3 but -3.1 to -3 + tensor = (tensor + torch.rand_like(tensor)).floor() + else: + tensor = tensor.round() + + # NOTE: is clipping necessary? + tensor = tensor.clip(-128, 127).to(torch.int8) + return tensor, scale.to(original_dtype) + + @classmethod + def from_float(cls, tensor: Tensor): + """Convert a float tensor into INT8 quantized weight. No stochastic rounding is performed.""" + int_data, scale = cls.quantize(tensor.detach()) + return cls(int_data, scale, requires_grad=tensor.requires_grad) + + def dequantize(self): + return self.int_data * self.scale.view(-1, 1) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(shape={tuple(self.shape)}, dtype={self.dtype}, device={self.device}, " + f"requires_grad={self.requires_grad})" + ) + + +@Int8QTLinearWeight.implements(torch.nn.functional.linear) +def _(func, types, args, kwargs): + return _Int8WeightOnlyLinear.apply(*args, **kwargs) + + +@Int8QTLinearWeight.implements(aten.detach.default) +def _(func, types, args, kwargs): + out = Int8QTLinearWeight(args[0].int_data, args[0].scale, requires_grad=False) + return return_and_correct_aliasing(func, args, kwargs, out) + + +@Int8QTLinearWeight.implements(aten.clone.default) +def _(func, types, args, kwargs): + out = Int8QTLinearWeight( + args[0].int_data.clone(), + args[0].scale.clone(), + requires_grad=args[0].requires_grad, + ) + return return_and_correct_aliasing(func, args, kwargs, out) + + +@Int8QTLinearWeight.implements(aten._to_copy.default) +def _(func, types, args, kwargs): + # we ignore memory_format in kwargs + # only perform dtype casting on scale, which determines the appearance dtype + device = kwargs.get("device", None) + dtype = kwargs.get("dtype", None) + out = Int8QTLinearWeight( + args[0].int_data.to(device=device), + args[0].scale.to(device=device, dtype=dtype), + requires_grad=args[0].requires_grad, + ) + return return_and_correct_aliasing(func, args, kwargs, out) + + +# to make training work with existing PyTorch optimizers, we return a normal tensor +@Int8QTLinearWeight.implements(aten.zeros_like.default) +def _(func, types, args, kwargs): + dtype = kwargs.get("dtype", args[0].dtype) + device = kwargs.get("device", args[0].device) + return torch.zeros(args[0].shape, dtype=dtype, device=device) + + +@Int8QTLinearWeight.implements([aten.sub.Tensor, aten.mul.Tensor]) +def _(func, types, args, kwargs): + args = [x.dequantize() if isinstance(x, Int8QTLinearWeight) else x for x in args] + return func(*args, **kwargs) + + +@Int8QTLinearWeight.implements(aten.copy_.default) +def _(func, types, args, kwargs): + if isinstance(args[0], Int8QTLinearWeight) and isinstance(args[1], Int8QTLinearWeight): + args[0].int_data.copy_(args[1].int_data) + args[0].scale.copy_(args[1].scale) + + elif isinstance(args[0], Int8QTLinearWeight): + int_data, scale = 
Int8QTLinearWeight.quantize(args[1], stochastic_rounding=True) + args[0].int_data.copy_(int_data) + args[0].scale.copy_(scale) + + else: + args[0].copy_(args[1].dequantize()) + + return args[0] + + +# this might be unnecessary +@Int8QTLinearWeight.implements(aten.addcdiv_.default) +def _(func, types, args, kwargs): + out = torch.addcdiv(args[0].dequantize(), *args[1:], **kwargs) + return args[0].copy_(out) + + +class _Int8WeightOnlyLinear(torch.autograd.Function): + @staticmethod + def forward(ctx, input: Tensor, weight: Int8QTLinearWeight, bias: Tensor | None = None): + ctx.save_for_backward(input, weight) + ctx.bias = bias is not None + + # NOTE: we have to .T before .to(input.dtype) for torch.compile() mixed matmul to work + out = (input @ weight.int_data.T.to(input.dtype)) * weight.scale + out = out + bias if bias is not None else out + return out + + @staticmethod + def backward(ctx, grad_output): + input, weight = ctx.saved_tensors + + dinput = (grad_output * weight.scale) @ weight.int_data.to(grad_output.dtype) + dweight = grad_output.flatten(0, -2).T @ input.flatten(0, -2) + dbias = grad_output.sum(0) if ctx.bias else None + return dinput, dweight, dbias + + +def int8_weight_only_quantized_training(): + return _get_linear_subclass_inserter(Int8QTLinearWeight.from_float) From eca170a08da5e18bda9a9381c2e6711093c7a59e Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 9 Aug 2024 20:16:55 +0800 Subject: [PATCH 02/41] add tests --- test/prototype/test_quantized_training.py | 77 +++++++++++++++++++ .../prototype/quantized_training/subclass.py | 13 +++- 2 files changed, 87 insertions(+), 3 deletions(-) create mode 100644 test/prototype/test_quantized_training.py diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py new file mode 100644 index 0000000000..62b2dd3018 --- /dev/null +++ b/test/prototype/test_quantized_training.py @@ -0,0 +1,77 @@ +import copy + +import torch +import torch.nn.functional as F +from torch import nn +from torch.testing._internal.common_utils import ( + TestCase, + instantiate_parametrized_tests, + parametrize, + run_tests, +) + +from torchao.prototype.quantized_training import Int8QTLinearWeight, int8_weight_only_quantized_training +from torchao.quantization.quant_api import quantize_ + + +_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) + + +class TestQuantizedTraining(TestCase): + @parametrize("device", _DEVICES) + def test_int8_stochastic_rounding(self, device): + x = torch.randn(32, device=device) + x_samples = x.view(1, -1).repeat(100_000, 1) + + x_int8, x_scale = Int8QTLinearWeight.quantize(x_samples, stochastic_rounding=True) + x_dequant_samples = x_int8 * x_scale.view(-1, 1) + x_dequant_mean = x_dequant_samples.mean(0) + + # a more rigorous test would be to do a hypothesis testing. + # due to the statistical nature, this assertion may still fail, though very rarely. 
+ torch.testing.assert_close(x_dequant_mean, x, atol=1e-4, rtol=1e-4) + + @parametrize("device", _DEVICES) + @parametrize("leading_dims", [(), (2,), (2, 4)]) + @parametrize("bias", [False, True]) + def test_int8_linear_forward(self, leading_dims, bias, device): + embed_dim = 32 + + linear_fp32 = nn.Linear(embed_dim, embed_dim * 2, bias=bias, device=device) + linear_int8 = copy.deepcopy(linear_fp32) + quantize_(linear_int8, int8_weight_only_quantized_training()) + assert isinstance(linear_int8.weight, Int8QTLinearWeight) + + inputs = torch.randn(leading_dims + (embed_dim,), device=device) + out_fp32 = linear_fp32(inputs) + out_int8 = linear_int8(inputs) + torch.testing.assert_close(out_fp32, out_int8, atol=1e-2, rtol=1e-2) + + @parametrize("device", _DEVICES) + def test_int8_linear_backward(self, device): + bsize = 4 + embed_dim = 32 + n_classes = 10 + + model_fp32 = nn.Sequential( + nn.Linear(embed_dim, embed_dim * 2, bias=False), + nn.GELU(), + nn.Linear(embed_dim * 2, n_classes), + ).to(device) + model_int8 = copy.deepcopy(model_fp32) + quantize_(model_int8, int8_weight_only_quantized_training()) + + inputs = torch.randn(bsize, embed_dim, device=device) + labels = torch.randint(n_classes, size=(bsize,), device=device) + F.cross_entropy(model_fp32(inputs), labels).backward() + F.cross_entropy(model_int8(inputs), labels).backward() + + for p_fp32, p_int8 in zip(model_fp32.parameters(), model_int8.parameters()): + torch.testing.assert_close(p_fp32.grad, p_int8.grad, atol=1e-3, rtol=1e-2) + + +instantiate_parametrized_tests(TestQuantizedTraining) + + +if __name__ == "__main__": + run_tests() diff --git a/torchao/prototype/quantized_training/subclass.py b/torchao/prototype/quantized_training/subclass.py index 7425af02a2..d1667d66b2 100644 --- a/torchao/prototype/quantized_training/subclass.py +++ b/torchao/prototype/quantized_training/subclass.py @@ -1,8 +1,8 @@ import torch -from torch import Tensor +from torch import Tensor, nn from torch.utils._python_dispatch import return_and_correct_aliasing -from torchao.dtypes.utils import _implements, _dispatch__torch_function__, _dispatch__torch_dispatch__ +from torchao.dtypes.utils import _dispatch__torch_dispatch__, _dispatch__torch_function__, _implements from torchao.quantization.quant_api import _get_linear_subclass_inserter aten = torch.ops.aten @@ -171,4 +171,11 @@ def backward(ctx, grad_output): def int8_weight_only_quantized_training(): - return _get_linear_subclass_inserter(Int8QTLinearWeight.from_float) + def apply_int8_linear_weight(linear: nn.Linear): + linear.weight = nn.Parameter( + Int8QTLinearWeight.from_float(linear.weight), + requires_grad=linear.weight.requires_grad, + ) + return linear + + return apply_int8_linear_weight From dd162a84ba96e72c9257884d88d94b4365f32239 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 9 Aug 2024 20:38:45 +0800 Subject: [PATCH 03/41] add training --- test/prototype/test_quantized_training.py | 33 +++++++++++++++++++ torchao/prototype/low_bit_optim/__init__.py | 2 +- torchao/prototype/low_bit_optim/adamw.py | 19 +++++++++-- .../prototype/quantized_training/subclass.py | 7 +++- 4 files changed, 56 insertions(+), 5 deletions(-) diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py index 62b2dd3018..8e28c055f7 100644 --- a/test/prototype/test_quantized_training.py +++ b/test/prototype/test_quantized_training.py @@ -11,6 +11,7 @@ ) from torchao.prototype.quantized_training import Int8QTLinearWeight, int8_weight_only_quantized_training +from 
torchao.prototype.low_bit_optim import AdamW from torchao.quantization.quant_api import quantize_ @@ -69,6 +70,38 @@ def test_int8_linear_backward(self, device): for p_fp32, p_int8 in zip(model_fp32.parameters(), model_int8.parameters()): torch.testing.assert_close(p_fp32.grad, p_int8.grad, atol=1e-3, rtol=1e-2) + @parametrize("device", _DEVICES) + def test_int8_linear_training(self, device): + bsize = 4 + embed_dim = 32 + n_classes = 10 + + model_fp32 = nn.Sequential( + nn.Linear(embed_dim, embed_dim * 2, bias=False), + nn.GELU(), + nn.Linear(embed_dim * 2, n_classes), + ).to(device) + model_int8 = copy.deepcopy(model_fp32) + quantize_(model_int8, int8_weight_only_quantized_training()) + + optim_fp32 = AdamW(model_fp32.parameters()) + optim_int8 = AdamW(model_int8.parameters()) + + for _ in range(2): + inputs = torch.randn(bsize, embed_dim, device=device) + labels = torch.randint(n_classes, size=(bsize,), device=device) + F.cross_entropy(model_fp32(inputs), labels).backward() + F.cross_entropy(model_int8(inputs), labels).backward() + + optim_fp32.step() + optim_fp32.zero_grad() + optim_int8.step() + optim_int8.zero_grad() + + with torch.no_grad(): + for p_fp32, p_int8 in zip(model_fp32.parameters(), model_int8.parameters()): + torch.testing.assert_close(p_fp32, p_int8.dequantize(), atol=1e-2, rtol=1e-2) + instantiate_parametrized_tests(TestQuantizedTraining) diff --git a/torchao/prototype/low_bit_optim/__init__.py b/torchao/prototype/low_bit_optim/__init__.py index 01729bc6a3..c351f7b48b 100644 --- a/torchao/prototype/low_bit_optim/__init__.py +++ b/torchao/prototype/low_bit_optim/__init__.py @@ -1,3 +1,3 @@ from .adam import Adam8bit, Adam4bit, AdamFp8 -from .adamw import AdamW8bit, AdamW4bit, AdamWFp8 +from .adamw import AdamW, AdamW8bit, AdamW4bit, AdamWFp8 from .cpu_offload import CPUOffloadOptimizer diff --git a/torchao/prototype/low_bit_optim/adamw.py b/torchao/prototype/low_bit_optim/adamw.py index dbde91fdd2..4205aeca90 100644 --- a/torchao/prototype/low_bit_optim/adamw.py +++ b/torchao/prototype/low_bit_optim/adamw.py @@ -131,8 +131,6 @@ def single_param_adamw( weight_decay: float, eps: float, ): - p.mul_(1 - lr * weight_decay) - bias_correction1 = 1 - beta1 ** step bias_correction2 = 1 - beta2 ** step @@ -150,8 +148,23 @@ def single_param_adamw( else: denom = (new_exp_avg_sq.sqrt() / bias_correction2.sqrt()).add_(eps) + # merge weight decay and param update in a single .add_() to make this work with quantized param step_size = lr / bias_correction1 - p.addcdiv_(new_exp_avg, denom, value=-step_size) + p.add_(-lr * weight_decay - step_size * new_exp_avg / denom) + + +class AdamW(_AdamW): + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2, + amsgrad=False, + ) -> None: + """AdamW optimizer that supports quantized training (parameter is quantized).""" + super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=float("inf")) class AdamW8bit(_AdamW): diff --git a/torchao/prototype/quantized_training/subclass.py b/torchao/prototype/quantized_training/subclass.py index d1667d66b2..c53699115c 100644 --- a/torchao/prototype/quantized_training/subclass.py +++ b/torchao/prototype/quantized_training/subclass.py @@ -142,13 +142,18 @@ def _(func, types, args, kwargs): return args[0] -# this might be unnecessary @Int8QTLinearWeight.implements(aten.addcdiv_.default) def _(func, types, args, kwargs): out = torch.addcdiv(args[0].dequantize(), *args[1:], **kwargs) return args[0].copy_(out) 
+@Int8QTLinearWeight.implements(aten.add_.Tensor) +def _(func, types, args, kwargs): + out = torch.add(args[0].dequantize(), *args[1:], **kwargs) + return args[0].copy_(out) + + class _Int8WeightOnlyLinear(torch.autograd.Function): @staticmethod def forward(ctx, input: Tensor, weight: Int8QTLinearWeight, bias: Tensor | None = None): From b286f5d6ef65ea665c05f8eced3a706f919c13bc Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 9 Aug 2024 21:07:37 +0800 Subject: [PATCH 04/41] support py3.9 --- torchao/prototype/quantized_training/subclass.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torchao/prototype/quantized_training/subclass.py b/torchao/prototype/quantized_training/subclass.py index c53699115c..6a36e47376 100644 --- a/torchao/prototype/quantized_training/subclass.py +++ b/torchao/prototype/quantized_training/subclass.py @@ -1,9 +1,11 @@ +from typing import Optional + import torch from torch import Tensor, nn from torch.utils._python_dispatch import return_and_correct_aliasing from torchao.dtypes.utils import _dispatch__torch_dispatch__, _dispatch__torch_function__, _implements -from torchao.quantization.quant_api import _get_linear_subclass_inserter + aten = torch.ops.aten @@ -156,7 +158,7 @@ def _(func, types, args, kwargs): class _Int8WeightOnlyLinear(torch.autograd.Function): @staticmethod - def forward(ctx, input: Tensor, weight: Int8QTLinearWeight, bias: Tensor | None = None): + def forward(ctx, input: Tensor, weight: Int8QTLinearWeight, bias: Optional[Tensor] = None): ctx.save_for_backward(input, weight) ctx.bias = bias is not None From 8a84acaa1a82fd600b472e2eca78992f1547355c Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 9 Aug 2024 21:26:01 +0800 Subject: [PATCH 05/41] skip test for torch<2.3 --- test/prototype/test_quantized_training.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py index 8e28c055f7..93b85bb7a5 100644 --- a/test/prototype/test_quantized_training.py +++ b/test/prototype/test_quantized_training.py @@ -1,5 +1,6 @@ import copy +import pytest import torch import torch.nn.functional as F from torch import nn @@ -10,9 +11,13 @@ run_tests, ) -from torchao.prototype.quantized_training import Int8QTLinearWeight, int8_weight_only_quantized_training from torchao.prototype.low_bit_optim import AdamW +from torchao.prototype.quantized_training import Int8QTLinearWeight, int8_weight_only_quantized_training from torchao.quantization.quant_api import quantize_ +from torchao.utils import TORCH_VERSION_AFTER_2_3 + +if not TORCH_VERSION_AFTER_2_3: + pytest.skip("Requires torch>=2.4") _DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) From ea47c7daa2ef36388e30b480470b8fd73cfcc4ff Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 9 Aug 2024 21:45:56 +0800 Subject: [PATCH 06/41] fix pytest --- test/prototype/test_quantized_training.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py index 93b85bb7a5..a15d0183d2 100644 --- a/test/prototype/test_quantized_training.py +++ b/test/prototype/test_quantized_training.py @@ -17,7 +17,7 @@ from torchao.utils import TORCH_VERSION_AFTER_2_3 if not TORCH_VERSION_AFTER_2_3: - pytest.skip("Requires torch>=2.4") + pytest.skip("Requires torch>=2.4", allow_module_level=True) _DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) From 
f20486bb3a130f27ca52c50e965dc3232bca537f Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 9 Aug 2024 22:30:21 +0800 Subject: [PATCH 07/41] fix adamw --- test/prototype/test_low_bit_optim.py | 24 ++++++++++++++++++++++++ torchao/prototype/low_bit_optim/adamw.py | 2 +- 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py index e69d0ed6fe..80cf9fc435 100644 --- a/test/prototype/test_low_bit_optim.py +++ b/test/prototype/test_low_bit_optim.py @@ -98,6 +98,30 @@ def test_optim_smoke(self, optim_name, dtype, device): optim.step() optim.zero_grad() + @parametrize("device", _DEVICES) + def test_optim_standard_correctness(self, device): + model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device) + model2 = copy.deepcopy(model1) + + optim1 = torch.optim.AdamW(model1.parameters()) + optim2 = low_bit_optim.AdamW(model2.parameters()) + + for _ in range(2): + x = torch.randn(4, 32, device=device) + + loss1 = model1(x).sum() + loss1.backward() + optim1.step() + optim1.zero_grad() + + loss2 = model2(x).sum() + loss2.backward() + optim2.step() + optim2.zero_grad() + + for p1, p2 in zip(model1.parameters(), model2.parameters()): + torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5) + @pytest.mark.skipif(bnb is None, reason="bitsandbytes is not availablle") @pytest.mark.skipif(not torch.cuda.is_available(), reason="bitsandbytes 8-bit Adam only works for CUDA") @pytest.mark.xfail(not TORCH_VERSION_AFTER_2_3, reason="torch.compile() fails for PyTorch < 2.3") diff --git a/torchao/prototype/low_bit_optim/adamw.py b/torchao/prototype/low_bit_optim/adamw.py index 4205aeca90..456629fde7 100644 --- a/torchao/prototype/low_bit_optim/adamw.py +++ b/torchao/prototype/low_bit_optim/adamw.py @@ -150,7 +150,7 @@ def single_param_adamw( # merge weight decay and param update in a single .add_() to make this work with quantized param step_size = lr / bias_correction1 - p.add_(-lr * weight_decay - step_size * new_exp_avg / denom) + p.add_(-lr * weight_decay * p - step_size * new_exp_avg / denom) class AdamW(_AdamW): From 3415244213aceae8feffa2ddb159c6e6870e8a0c Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 9 Aug 2024 22:35:01 +0800 Subject: [PATCH 08/41] add some FSDP ops --- .../prototype/quantized_training/subclass.py | 35 +++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/torchao/prototype/quantized_training/subclass.py b/torchao/prototype/quantized_training/subclass.py index 6a36e47376..c5ad83789a 100644 --- a/torchao/prototype/quantized_training/subclass.py +++ b/torchao/prototype/quantized_training/subclass.py @@ -8,6 +8,8 @@ aten = torch.ops.aten +c10d_functional = torch.ops.c10d_functional +_c10d_functional = torch.ops._c10d_functional # the main difference of this tensor subclass from AffineQuantizedTensor: @@ -156,6 +158,39 @@ def _(func, types, args, kwargs): return args[0].copy_(out) +# FSDP ops +@Int8QTLinearWeight.implements(aten.split.Tensor) +def _(func, types, args, kwargs): + if len(args) == 3 and args[2] != 0: + raise NotImplementedError("Int8QTLinearWeight only supports split at dim=0") + + int8_weight: Int8QTLinearWeight = args[0] + if int8_weight.ndim != 2: + raise NotImplementedError("Int8QTLinearWeight only supports split when ndim=2") + + int_data_list = func(int8_weight.int_data, *args[1:], **kwargs) + scale_list = func(int8_weight.scale, *args[1:], **kwargs) + return [ + Int8QTLinearWeight(int_data, scale, requires_grad=int8_weight.requires_grad) + for 
int_data, scale in zip(int_data_list, scale_list) + ] + + +@Int8QTLinearWeight.implements([ + c10d_functional.all_gather_into_tensor.default, + _c10d_functional.all_gather_into_tensor.default, + c10d_functional.wait_tensor.default, + _c10d_functional.wait_tensor.default, +]) +def _(func, types, args, kwargs): + x: Int8QTLinearWeight = args[0] + return Int8QTLinearWeight( + func(x.int_data, *args[1:], **kwargs), + func(x.scale, *args[1:], **kwargs), + requires_grad=x.requires_grad, + ) + + class _Int8WeightOnlyLinear(torch.autograd.Function): @staticmethod def forward(ctx, input: Tensor, weight: Int8QTLinearWeight, bias: Optional[Tensor] = None): From 5d0e65861c381bf1ffa429005f93591c689945bb Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 10 Aug 2024 10:07:57 +0800 Subject: [PATCH 09/41] add more fsdp ops --- .../prototype/quantized_training/subclass.py | 60 ++++++++++++++----- 1 file changed, 45 insertions(+), 15 deletions(-) diff --git a/torchao/prototype/quantized_training/subclass.py b/torchao/prototype/quantized_training/subclass.py index c5ad83789a..908859f042 100644 --- a/torchao/prototype/quantized_training/subclass.py +++ b/torchao/prototype/quantized_training/subclass.py @@ -20,6 +20,7 @@ class Int8QTLinearWeight(Tensor): __torch_function__ = classmethod(_dispatch__torch_function__) __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__) + @staticmethod def __new__(cls, int_data, scale, requires_grad=False): return Tensor._make_wrapper_subclass( cls, @@ -35,6 +36,8 @@ def __init__(self, int_data, scale, requires_grad=False): """ # NOTE: should scale always be FP32? assert int_data.dtype is torch.int8 + assert int_data.ndim == 2 + assert scale.ndim == 1 self.int_data = int_data self.scale = scale @@ -91,11 +94,12 @@ def _(func, types, args, kwargs): return return_and_correct_aliasing(func, args, kwargs, out) -@Int8QTLinearWeight.implements(aten.clone.default) +@Int8QTLinearWeight.implements([aten.clone.default, aten.slice.Tensor]) def _(func, types, args, kwargs): + # will error out if try to slice 2nd dim out = Int8QTLinearWeight( - args[0].int_data.clone(), - args[0].scale.clone(), + func(args[0].int_data, *args[1:], **kwargs), + func(args[0].scale, *args[1:], **kwargs), requires_grad=args[0].requires_grad, ) return return_and_correct_aliasing(func, args, kwargs, out) @@ -165,23 +169,49 @@ def _(func, types, args, kwargs): raise NotImplementedError("Int8QTLinearWeight only supports split at dim=0") int8_weight: Int8QTLinearWeight = args[0] - if int8_weight.ndim != 2: - raise NotImplementedError("Int8QTLinearWeight only supports split when ndim=2") - int_data_list = func(int8_weight.int_data, *args[1:], **kwargs) scale_list = func(int8_weight.scale, *args[1:], **kwargs) - return [ - Int8QTLinearWeight(int_data, scale, requires_grad=int8_weight.requires_grad) - for int_data, scale in zip(int_data_list, scale_list) + + # requires_grad must be False here + out = [ + Int8QTLinearWeight(int_data, scale, requires_grad=False) for int_data, scale in zip(int_data_list, scale_list) ] + return out + +@Int8QTLinearWeight.implements(aten.new_zeros.default) +def _(func, types, args, kwargs): + size = args[1] + if len(size) != 2: + raise NotImplementedError + + # ignore other kwargs. NOTE: is requires_grad needed? 
+ device = kwargs.get("device") + dtype = kwargs.get("dtype") + int_data = torch.zeros(size, device=device, dtype=torch.int8) + scale = torch.zeros(size[0], device=device, dtype=dtype) + return Int8QTLinearWeight(int_data, scale) -@Int8QTLinearWeight.implements([ - c10d_functional.all_gather_into_tensor.default, - _c10d_functional.all_gather_into_tensor.default, - c10d_functional.wait_tensor.default, - _c10d_functional.wait_tensor.default, -]) + +@Int8QTLinearWeight.implements(aten.view.default) +def _(func, types, args, kwargs): + # don't do anything. workaround for FSDP2. might give unexpected results + out = Int8QTLinearWeight( + args[0].int_data, + args[0].scale, + requires_grad=args[0].requires_grad, + ) + return return_and_correct_aliasing(func, args, kwargs, out) + + +@Int8QTLinearWeight.implements( + [ + c10d_functional.all_gather_into_tensor.default, + _c10d_functional.all_gather_into_tensor.default, + c10d_functional.wait_tensor.default, + _c10d_functional.wait_tensor.default, + ] +) def _(func, types, args, kwargs): x: Int8QTLinearWeight = args[0] return Int8QTLinearWeight( From d7534768750b25fdc99af49ada2fc9ce0a83d18a Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 10 Aug 2024 10:22:26 +0800 Subject: [PATCH 10/41] more ops --- .../prototype/quantized_training/subclass.py | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/torchao/prototype/quantized_training/subclass.py b/torchao/prototype/quantized_training/subclass.py index 908859f042..bfdf15b289 100644 --- a/torchao/prototype/quantized_training/subclass.py +++ b/torchao/prototype/quantized_training/subclass.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Any, Optional, Tuple import torch from torch import Tensor, nn @@ -82,6 +82,20 @@ def __repr__(self): f"requires_grad={self.requires_grad})" ) + def fsdp_pre_all_gather(self, mesh): + return (self.int_data, self.scale), None + + def fsdp_post_all_gather( + self, + all_gather_outputs: Tuple[Tensor, ...], + metadata: Any, + param_dtype: torch.dtype, + *, + out: Optional[Tensor] = None, + ): + int_data, scale = all_gather_outputs + return Int8QTLinearWeight(int_data, scale), all_gather_outputs + @Int8QTLinearWeight.implements(torch.nn.functional.linear) def _(func, types, args, kwargs): @@ -193,9 +207,9 @@ def _(func, types, args, kwargs): return Int8QTLinearWeight(int_data, scale) -@Int8QTLinearWeight.implements(aten.view.default) +# don't do anything. workaround for FSDP2. might give unexpected or wrong results. +@Int8QTLinearWeight.implements([aten.view.default, aten.as_strided.default]) def _(func, types, args, kwargs): - # don't do anything. workaround for FSDP2. 
might give unexpected results out = Int8QTLinearWeight( args[0].int_data, args[0].scale, From 9c778007b0c449e67a983ceceb786720c6f8500c Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 10 Aug 2024 11:03:55 +0800 Subject: [PATCH 11/41] add benchmark script --- benchmarks/benchmark_int8_qt.py | 143 ++++++++++++++++++++++++++++++++ 1 file changed, 143 insertions(+) create mode 100644 benchmarks/benchmark_int8_qt.py diff --git a/benchmarks/benchmark_int8_qt.py b/benchmarks/benchmark_int8_qt.py new file mode 100644 index 0000000000..f1bb7becb3 --- /dev/null +++ b/benchmarks/benchmark_int8_qt.py @@ -0,0 +1,143 @@ +# pre-train a mini Llama2 on TinyStories with INT8 quantized training +# pip install transformers sentencepiece wandb +# +# BF16 baseline: python benchmarks/benchmark_int8_qt.py --seed 2024 --step 10_000 +# INT8 QT: python benchamrks/benchmark_int8_qt.py --seed 2024 --step 10_000 --quantize int8_weight_only + +import os + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" + +import argparse +from pathlib import Path + +import numpy as np +import torch +import wandb +from tqdm import tqdm +from transformers import LlamaConfig, LlamaForCausalLM + +from torchao.prototype.low_bit_optim import AdamW +from torchao.prototype.quantized_training import int8_weight_only_quantized_training +from torchao.quantization.quant_api import quantize_ + + +def get_loss(model: LlamaForCausalLM, batch: torch.Tensor): + return model(batch, labels=batch).loss + + +def get_tinystories(): + save_path = Path("tinystories.bin") + + if not save_path.exists(): + import sentencepiece as spm + from huggingface_hub import hf_hub_download + + tokenizer_path = hf_hub_download("meta-llama/Llama-2-7b", "tokenizer.model") + tokenizer = spm.SentencePieceProcessor(tokenizer_path) + assert tokenizer.vocab_size() < (1 << 16) # make sure we can use uint16 + + # do everything in memory. 
we have enough RAM + filepath = hf_hub_download( + "roneneldan/TinyStories", + "TinyStoriesV2-GPT4-train.txt", + repo_type="dataset", + ) + stories = open(filepath).read().split("\n<|endoftext|>\n") + + tokens_list = [] + chunk_size = 10_000 + for i in tqdm(range(0, len(stories), chunk_size), desc="Tokenizing TinyStories"): + chunk = stories[i : min(i + chunk_size, len(stories))] + tokens_list.extend(tokenizer.Encode(chunk, add_bos=True, add_eos=True, num_threads=4)) + + total_size = sum(len(x) for x in tokens_list) + mmap_tokens = np.memmap(save_path, dtype=np.uint16, mode="w+", shape=total_size) + i = 0 + for tokens in tokens_list: + mmap_tokens[i : i + len(tokens)] = tokens + i += len(tokens) + mmap_tokens.flush() + + tokens = np.memmap(save_path, dtype=np.uint16, mode="r") + return torch.from_numpy(tokens) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # default config is 470M + parser.add_argument("--d_model", type=int, default=1024) + parser.add_argument("--depth", type=int, default=24) + parser.add_argument("--ffn_size", type=int, default=4096) + parser.add_argument("--head_dim", type=int, default=64) + + parser.add_argument("--quantize") + parser.add_argument("--activation_checkpointing", action="store_true") + + parser.add_argument("--n_steps", type=int, default=1000) + parser.add_argument("--batch_size", type=int, default=4) + parser.add_argument("--seq_len", type=int, default=2048) + + parser.add_argument("--lr", type=float, default=3e-4) + parser.add_argument("--weight_decay", type=float, default=1e-2) + + parser.add_argument("--project", default="int8_quantized_training") + parser.add_argument("--run_name") + parser.add_argument("--seed", type=int) + args = parser.parse_args() + + if args.seed is not None: + torch.manual_seed(args.seed) + + config = LlamaConfig( + hidden_size=args.d_model, + intermediate_size=args.ffn_size, + num_hidden_layers=args.depth, + num_attention_heads=args.d_model // args.head_dim, + max_position_embeddings=args.seq_len, + use_cache=False, + ) + model = LlamaForCausalLM(config).bfloat16().cuda() + if args.activation_checkpointing: + model.gradient_checkpointing_enable() + if args.quantize == "int8_weight_only": + quantize_(model, int8_weight_only_quantized_training()) + elif args.quantize is not None: + raise ValueError(f"Unsupported quantize={args.quantize}") + print(f"No. of params: {sum(p.numel() for p in model.parameters())}") + print(f"No. 
of buffers: {sum(p.numel() for p in model.buffers())}") + + optim = AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) + + data = get_tinystories().cuda() + run = wandb.init(dir="/tmp", config=args, project=args.project, name=args.run_name) + + step = 0 + log_interval = 50 + pbar = tqdm(total=args.n_steps, dynamic_ncols=True) + model.train() + + while step < args.n_steps: + # randomly select a continuous chunk, then reshape it + idx = torch.randint(0, data.shape[0] - args.batch_size * args.seq_len, (1,)).item() + batch = data[idx : idx + args.batch_size * args.seq_len].view(args.batch_size, args.seq_len).long() + + loss = torch.compile(get_loss)(model, batch) + loss.backward() + + if step % log_interval == 0: + log_dict = dict( + loss=loss.item(), + lr=optim.param_groups[0]["lr"], + max_memory_allocated=torch.cuda.max_memory_allocated(), + ) + run.log(log_dict, step=step) + pbar.set_postfix(loss=log_dict["loss"]) + + optim.step() + optim.zero_grad() + + step += 1 + pbar.update() + + run.finish() From 158eb6108eb1d77c36c11d24e40918f8a9e17397 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 10 Aug 2024 11:14:50 +0800 Subject: [PATCH 12/41] some organisation --- .../prototype/quantized_training/subclass.py | 76 +++++++++---------- 1 file changed, 34 insertions(+), 42 deletions(-) diff --git a/torchao/prototype/quantized_training/subclass.py b/torchao/prototype/quantized_training/subclass.py index bfdf15b289..c8989b7495 100644 --- a/torchao/prototype/quantized_training/subclass.py +++ b/torchao/prototype/quantized_training/subclass.py @@ -97,24 +97,44 @@ def fsdp_post_all_gather( return Int8QTLinearWeight(int_data, scale), all_gather_outputs -@Int8QTLinearWeight.implements(torch.nn.functional.linear) -def _(func, types, args, kwargs): - return _Int8WeightOnlyLinear.apply(*args, **kwargs) +class _Int8WeightOnlyLinear(torch.autograd.Function): + @staticmethod + def forward(ctx, input: Tensor, weight: Int8QTLinearWeight, bias: Optional[Tensor] = None): + ctx.save_for_backward(input, weight) + ctx.bias = bias is not None + + # NOTE: we have to .T before .to(input.dtype) for torch.compile() mixed matmul to work + out = (input @ weight.int_data.T.to(input.dtype)) * weight.scale + out = out + bias if bias is not None else out + return out + + @staticmethod + def backward(ctx, grad_output): + input, weight = ctx.saved_tensors + + dinput = (grad_output * weight.scale) @ weight.int_data.to(grad_output.dtype) + dweight = grad_output.flatten(0, -2).T @ input.flatten(0, -2) + dbias = grad_output.sum(0) if ctx.bias else None + return dinput, dweight, dbias -@Int8QTLinearWeight.implements(aten.detach.default) +@Int8QTLinearWeight.implements(torch.nn.functional.linear) def _(func, types, args, kwargs): - out = Int8QTLinearWeight(args[0].int_data, args[0].scale, requires_grad=False) - return return_and_correct_aliasing(func, args, kwargs, out) + return _Int8WeightOnlyLinear.apply(*args, **kwargs) -@Int8QTLinearWeight.implements([aten.clone.default, aten.slice.Tensor]) +@Int8QTLinearWeight.implements( + [ + aten.detach.default, + aten.clone.default, + aten.slice.Tensor, + ] +) def _(func, types, args, kwargs): # will error out if try to slice 2nd dim out = Int8QTLinearWeight( func(args[0].int_data, *args[1:], **kwargs), func(args[0].scale, *args[1:], **kwargs), - requires_grad=args[0].requires_grad, ) return return_and_correct_aliasing(func, args, kwargs, out) @@ -141,6 +161,7 @@ def _(func, types, args, kwargs): return torch.zeros(args[0].shape, dtype=dtype, device=device) +# 
out-of-place math ops always return plain tensor @Int8QTLinearWeight.implements([aten.sub.Tensor, aten.mul.Tensor]) def _(func, types, args, kwargs): args = [x.dequantize() if isinstance(x, Int8QTLinearWeight) else x for x in args] @@ -164,16 +185,11 @@ def _(func, types, args, kwargs): return args[0] -@Int8QTLinearWeight.implements(aten.addcdiv_.default) +@Int8QTLinearWeight.implements([aten.addcdiv_.default, aten.add_.Tensor]) def _(func, types, args, kwargs): - out = torch.addcdiv(args[0].dequantize(), *args[1:], **kwargs) - return args[0].copy_(out) - - -@Int8QTLinearWeight.implements(aten.add_.Tensor) -def _(func, types, args, kwargs): - out = torch.add(args[0].dequantize(), *args[1:], **kwargs) - return args[0].copy_(out) + original = args[0] + out = func(args[0].dequantize(), *args[1:], **kwargs) + return original.copy_(out) # FSDP ops @@ -186,10 +202,7 @@ def _(func, types, args, kwargs): int_data_list = func(int8_weight.int_data, *args[1:], **kwargs) scale_list = func(int8_weight.scale, *args[1:], **kwargs) - # requires_grad must be False here - out = [ - Int8QTLinearWeight(int_data, scale, requires_grad=False) for int_data, scale in zip(int_data_list, scale_list) - ] + out = [Int8QTLinearWeight(int_data, scale) for int_data, scale in zip(int_data_list, scale_list)] return out @@ -235,27 +248,6 @@ def _(func, types, args, kwargs): ) -class _Int8WeightOnlyLinear(torch.autograd.Function): - @staticmethod - def forward(ctx, input: Tensor, weight: Int8QTLinearWeight, bias: Optional[Tensor] = None): - ctx.save_for_backward(input, weight) - ctx.bias = bias is not None - - # NOTE: we have to .T before .to(input.dtype) for torch.compile() mixed matmul to work - out = (input @ weight.int_data.T.to(input.dtype)) * weight.scale - out = out + bias if bias is not None else out - return out - - @staticmethod - def backward(ctx, grad_output): - input, weight = ctx.saved_tensors - - dinput = (grad_output * weight.scale) @ weight.int_data.to(grad_output.dtype) - dweight = grad_output.flatten(0, -2).T @ input.flatten(0, -2) - dbias = grad_output.sum(0) if ctx.bias else None - return dinput, dweight, dbias - - def int8_weight_only_quantized_training(): def apply_int8_linear_weight(linear: nn.Linear): linear.weight = nn.Parameter( From db0290fb985ac402f14013f10815f3e80ac9f01a Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 10 Aug 2024 11:49:33 +0800 Subject: [PATCH 13/41] add FSDP test --- benchmarks/benchmark_int8_qt.py | 7 +- test/prototype/test_quantized_training.py | 76 +++++++++++++++++-- .../prototype/quantized_training/subclass.py | 8 +- 3 files changed, 79 insertions(+), 12 deletions(-) diff --git a/benchmarks/benchmark_int8_qt.py b/benchmarks/benchmark_int8_qt.py index f1bb7becb3..a7595b118c 100644 --- a/benchmarks/benchmark_int8_qt.py +++ b/benchmarks/benchmark_int8_qt.py @@ -1,6 +1,6 @@ # pre-train a mini Llama2 on TinyStories with INT8 quantized training # pip install transformers sentencepiece wandb -# +# # BF16 baseline: python benchmarks/benchmark_int8_qt.py --seed 2024 --step 10_000 # INT8 QT: python benchamrks/benchmark_int8_qt.py --seed 2024 --step 10_000 --quantize int8_weight_only @@ -17,7 +17,7 @@ from tqdm import tqdm from transformers import LlamaConfig, LlamaForCausalLM -from torchao.prototype.low_bit_optim import AdamW +from torchao.prototype import low_bit_optim from torchao.prototype.quantized_training import int8_weight_only_quantized_training from torchao.quantization.quant_api import quantize_ @@ -78,6 +78,7 @@ def get_tinystories(): 
parser.add_argument("--batch_size", type=int, default=4) parser.add_argument("--seq_len", type=int, default=2048) + parser.add_argument("--optim", default="AdamW") parser.add_argument("--lr", type=float, default=3e-4) parser.add_argument("--weight_decay", type=float, default=1e-2) @@ -107,7 +108,7 @@ def get_tinystories(): print(f"No. of params: {sum(p.numel() for p in model.parameters())}") print(f"No. of buffers: {sum(p.numel() for p in model.buffers())}") - optim = AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) + optim = getattr(low_bit_optim, args.oprim)(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) data = get_tinystories().cuda() run = wandb.init(dir="/tmp", config=args, project=args.project, name=args.run_name) diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py index a15d0183d2..228bfa8398 100644 --- a/test/prototype/test_quantized_training.py +++ b/test/prototype/test_quantized_training.py @@ -4,12 +4,9 @@ import torch import torch.nn.functional as F from torch import nn -from torch.testing._internal.common_utils import ( - TestCase, - instantiate_parametrized_tests, - parametrize, - run_tests, -) +from torch.testing._internal.common_distributed import skip_if_lt_x_gpu +from torch.testing._internal.common_fsdp import FSDPTest +from torch.testing._internal.common_utils import TestCase, instantiate_parametrized_tests, parametrize, run_tests from torchao.prototype.low_bit_optim import AdamW from torchao.prototype.quantized_training import Int8QTLinearWeight, int8_weight_only_quantized_training @@ -108,6 +105,73 @@ def test_int8_linear_training(self, device): torch.testing.assert_close(p_fp32, p_int8.dequantize(), atol=1e-2, rtol=1e-2) +class TestFSDP2(FSDPTest): + @property + def world_size(self) -> int: + return 2 + + @skip_if_lt_x_gpu(2) + def test_fsdp2(self): + self.run_subtests( + {"activation_checkpointing": [False, True]}, + self._test_fsdp2, + ) + + def _test_fsdp2(self, activation_checkpointing): + import torch.distributed as dist + from torch.distributed._composable.fsdp import fully_shard + from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + CheckpointWrapper, + apply_activation_checkpointing, + ) + from torch.distributed.fsdp.wrap import ModuleWrapPolicy + from torch.testing._internal.distributed._tensor.common_dtensor import ModelArgs, Transformer, TransformerBlock + + batch_size = 3 + vocab_size = 1024 + seq_len = 64 + model_args = ModelArgs( + n_layers=3, + n_heads=4, + dim=1024, + vocab_size=vocab_size, + max_seq_len=seq_len, + dropout_p=0, + ) + torch.manual_seed(42) + base_model = Transformer(model_args).cuda() + quantize_(base_model, int8_weight_only_quantized_training()) + if activation_checkpointing: + policy = ModuleWrapPolicy({TransformerBlock}) + apply_activation_checkpointing(base_model, auto_wrap_policy=policy) + base_optim = AdamW(base_model.parameters(), lr=1e-2) + + fsdp_model = copy.deepcopy(base_model) + for m in fsdp_model.modules(): + cls_to_shard = CheckpointWrapper if activation_checkpointing else TransformerBlock + if isinstance(m, cls_to_shard): + fully_shard(m) + fully_shard(fsdp_model) + fsdp_optim = AdamW(fsdp_model.parameters(), lr=1e-2) + + torch.manual_seed(42 + self.rank + 1) + for iter_idx in range(5): + inp = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda") + fsdp_optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) + fsdp_loss = fsdp_model(inp).sum() + fsdp_loss.backward() + fsdp_optim.step() + + 
base_optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) + base_loss = base_model(inp).sum() + base_loss.backward() + for param in base_model.parameters(): + if param.grad is not None: + dist.all_reduce(param.grad, op=dist.ReduceOp.AVG) + base_optim.step() + self.assertEqual(fsdp_loss, base_loss) + + instantiate_parametrized_tests(TestQuantizedTraining) diff --git a/torchao/prototype/quantized_training/subclass.py b/torchao/prototype/quantized_training/subclass.py index c8989b7495..02caa76b0a 100644 --- a/torchao/prototype/quantized_training/subclass.py +++ b/torchao/prototype/quantized_training/subclass.py @@ -139,6 +139,7 @@ def _(func, types, args, kwargs): return return_and_correct_aliasing(func, args, kwargs, out) +# TODO: also handle non_blocking kwarg? @Int8QTLinearWeight.implements(aten._to_copy.default) def _(func, types, args, kwargs): # we ignore memory_format in kwargs @@ -168,6 +169,7 @@ def _(func, types, args, kwargs): return func(*args, **kwargs) +# TODO: handle non_blocking kwarg? @Int8QTLinearWeight.implements(aten.copy_.default) def _(func, types, args, kwargs): if isinstance(args[0], Int8QTLinearWeight) and isinstance(args[1], Int8QTLinearWeight): @@ -212,9 +214,9 @@ def _(func, types, args, kwargs): if len(size) != 2: raise NotImplementedError - # ignore other kwargs. NOTE: is requires_grad needed? - device = kwargs.get("device") - dtype = kwargs.get("dtype") + # TODO: handle pin_memory kwarg? + device = kwargs.get("device", args[0].device) + dtype = kwargs.get("dtype", args[0].dtype) int_data = torch.zeros(size, device=device, dtype=torch.int8) scale = torch.zeros(size[0], device=device, dtype=dtype) return Int8QTLinearWeight(int_data, scale) From 1c32b78e1c2de1a2e0450597e69a6bdc7a528d9c Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 10 Aug 2024 13:21:23 +0800 Subject: [PATCH 14/41] clean up --- .../prototype/quantized_training/subclass.py | 43 ++++++------------- 1 file changed, 12 insertions(+), 31 deletions(-) diff --git a/torchao/prototype/quantized_training/subclass.py b/torchao/prototype/quantized_training/subclass.py index 02caa76b0a..ef7ff0fbce 100644 --- a/torchao/prototype/quantized_training/subclass.py +++ b/torchao/prototype/quantized_training/subclass.py @@ -127,7 +127,12 @@ def _(func, types, args, kwargs): [ aten.detach.default, aten.clone.default, + # FSDP ops aten.slice.Tensor, + c10d_functional.all_gather_into_tensor.default, + _c10d_functional.all_gather_into_tensor.default, + c10d_functional.wait_tensor.default, + _c10d_functional.wait_tensor.default, ] ) def _(func, types, args, kwargs): @@ -139,17 +144,15 @@ def _(func, types, args, kwargs): return return_and_correct_aliasing(func, args, kwargs, out) -# TODO: also handle non_blocking kwarg? @Int8QTLinearWeight.implements(aten._to_copy.default) def _(func, types, args, kwargs): - # we ignore memory_format in kwargs # only perform dtype casting on scale, which determines the appearance dtype + # TODO: handle non_blocking kwarg? device = kwargs.get("device", None) dtype = kwargs.get("dtype", None) out = Int8QTLinearWeight( args[0].int_data.to(device=device), args[0].scale.to(device=device, dtype=dtype), - requires_grad=args[0].requires_grad, ) return return_and_correct_aliasing(func, args, kwargs, out) @@ -169,20 +172,19 @@ def _(func, types, args, kwargs): return func(*args, **kwargs) -# TODO: handle non_blocking kwarg? 
@Int8QTLinearWeight.implements(aten.copy_.default) def _(func, types, args, kwargs): if isinstance(args[0], Int8QTLinearWeight) and isinstance(args[1], Int8QTLinearWeight): - args[0].int_data.copy_(args[1].int_data) - args[0].scale.copy_(args[1].scale) + args[0].int_data.copy_(args[1].int_data, **kwargs) + args[0].scale.copy_(args[1].scale, **kwargs) elif isinstance(args[0], Int8QTLinearWeight): int_data, scale = Int8QTLinearWeight.quantize(args[1], stochastic_rounding=True) - args[0].int_data.copy_(int_data) - args[0].scale.copy_(scale) + args[0].int_data.copy_(int_data, **kwargs) + args[0].scale.copy_(scale, **kwargs) else: - args[0].copy_(args[1].dequantize()) + args[0].copy_(args[1].dequantize(), **kwargs) return args[0] @@ -225,31 +227,10 @@ def _(func, types, args, kwargs): # don't do anything. workaround for FSDP2. might give unexpected or wrong results. @Int8QTLinearWeight.implements([aten.view.default, aten.as_strided.default]) def _(func, types, args, kwargs): - out = Int8QTLinearWeight( - args[0].int_data, - args[0].scale, - requires_grad=args[0].requires_grad, - ) + out = Int8QTLinearWeight(args[0].int_data, args[0].scale) return return_and_correct_aliasing(func, args, kwargs, out) -@Int8QTLinearWeight.implements( - [ - c10d_functional.all_gather_into_tensor.default, - _c10d_functional.all_gather_into_tensor.default, - c10d_functional.wait_tensor.default, - _c10d_functional.wait_tensor.default, - ] -) -def _(func, types, args, kwargs): - x: Int8QTLinearWeight = args[0] - return Int8QTLinearWeight( - func(x.int_data, *args[1:], **kwargs), - func(x.scale, *args[1:], **kwargs), - requires_grad=x.requires_grad, - ) - - def int8_weight_only_quantized_training(): def apply_int8_linear_weight(linear: nn.Linear): linear.weight = nn.Parameter( From ff69121f464a97ae96008c71f206784c17af8294 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 10 Aug 2024 14:01:37 +0800 Subject: [PATCH 15/41] update FSDP test --- test/prototype/test_quantized_training.py | 36 +++++++++++-------- .../prototype/quantized_training/subclass.py | 4 +-- 2 files changed, 24 insertions(+), 16 deletions(-) diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py index 228bfa8398..cf76db2567 100644 --- a/test/prototype/test_quantized_training.py +++ b/test/prototype/test_quantized_training.py @@ -113,22 +113,22 @@ def world_size(self) -> int: @skip_if_lt_x_gpu(2) def test_fsdp2(self): self.run_subtests( - {"activation_checkpointing": [False, True]}, + { + "activation_checkpointing": [False, True], + # "compile_layer": [False, True], + }, self._test_fsdp2, ) - def _test_fsdp2(self, activation_checkpointing): + def _test_fsdp2(self, activation_checkpointing, compile_layer): import torch.distributed as dist from torch.distributed._composable.fsdp import fully_shard - from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( - CheckpointWrapper, - apply_activation_checkpointing, - ) + from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import apply_activation_checkpointing from torch.distributed.fsdp.wrap import ModuleWrapPolicy from torch.testing._internal.distributed._tensor.common_dtensor import ModelArgs, Transformer, TransformerBlock batch_size = 3 - vocab_size = 1024 + vocab_size = 32 seq_len = 64 model_args = ModelArgs( n_layers=3, @@ -144,14 +144,19 @@ def _test_fsdp2(self, activation_checkpointing): if activation_checkpointing: policy = ModuleWrapPolicy({TransformerBlock}) apply_activation_checkpointing(base_model, auto_wrap_policy=policy) - 
base_optim = AdamW(base_model.parameters(), lr=1e-2) - fsdp_model = copy.deepcopy(base_model) - for m in fsdp_model.modules(): - cls_to_shard = CheckpointWrapper if activation_checkpointing else TransformerBlock - if isinstance(m, cls_to_shard): - fully_shard(m) + + if compile_layer: + for layer in base_model.layers: + layer.compile() + + for layer in fsdp_model.layers: + if compile_layer: + layer.compile() + fully_shard(layer) fully_shard(fsdp_model) + + base_optim = AdamW(base_model.parameters(), lr=1e-2) fsdp_optim = AdamW(fsdp_model.parameters(), lr=1e-2) torch.manual_seed(42 + self.rank + 1) @@ -169,7 +174,10 @@ def _test_fsdp2(self, activation_checkpointing): if param.grad is not None: dist.all_reduce(param.grad, op=dist.ReduceOp.AVG) base_optim.step() - self.assertEqual(fsdp_loss, base_loss) + + # due to stochastic rounding, use a pretty large tolerance here + rel_error = (fsdp_loss - base_loss).abs() / base_loss.abs() + assert rel_error < 0.05, rel_error instantiate_parametrized_tests(TestQuantizedTraining) diff --git a/torchao/prototype/quantized_training/subclass.py b/torchao/prototype/quantized_training/subclass.py index ef7ff0fbce..06d1f3bc53 100644 --- a/torchao/prototype/quantized_training/subclass.py +++ b/torchao/prototype/quantized_training/subclass.py @@ -21,7 +21,7 @@ class Int8QTLinearWeight(Tensor): __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__) @staticmethod - def __new__(cls, int_data, scale, requires_grad=False): + def __new__(cls, int_data: Tensor, scale: Tensor, requires_grad: bool = False): return Tensor._make_wrapper_subclass( cls, int_data.shape, @@ -30,7 +30,7 @@ def __new__(cls, int_data, scale, requires_grad=False): requires_grad=requires_grad, ) - def __init__(self, int_data, scale, requires_grad=False): + def __init__(self, int_data: Tensor, scale: Tensor, requires_grad: bool = False): """Create a symmetric quantized INT8 weight. This tensor will appear to have the same dtype as `scale.dtype`. All in-place update ops will perform stochastic rounding. """ From 45342ba7ccb4fe018ebe974e92fa57c4000fb89e Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 10 Aug 2024 14:15:50 +0800 Subject: [PATCH 16/41] add compile test (things are crashing) --- test/prototype/test_quantized_training.py | 27 ++++++++++++++++++- .../prototype/quantized_training/subclass.py | 2 +- 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py index cf76db2567..cbd5f76832 100644 --- a/test/prototype/test_quantized_training.py +++ b/test/prototype/test_quantized_training.py @@ -34,9 +34,9 @@ def test_int8_stochastic_rounding(self, device): # due to the statistical nature, this assertion may still fail, though very rarely. 
torch.testing.assert_close(x_dequant_mean, x, atol=1e-4, rtol=1e-4) - @parametrize("device", _DEVICES) @parametrize("leading_dims", [(), (2,), (2, 4)]) @parametrize("bias", [False, True]) + @parametrize("device", _DEVICES) def test_int8_linear_forward(self, leading_dims, bias, device): embed_dim = 32 @@ -72,6 +72,31 @@ def test_int8_linear_backward(self, device): for p_fp32, p_int8 in zip(model_fp32.parameters(), model_int8.parameters()): torch.testing.assert_close(p_fp32.grad, p_int8.grad, atol=1e-3, rtol=1e-2) + @parametrize("bias", [False, True]) + @parametrize("device", _DEVICES) + def test_int8_linear_compile(self, bias, device): + bsize = 4 + embed_dim = 32 + n_classes = 10 + + linear = nn.Linear(embed_dim, n_classes, bias=bias, device=device) + quantize_(linear, int8_weight_only_quantized_training()) + linear_compiled = copy.deepcopy(linear) + linear_compiled.compile() + + inputs = torch.randn((bsize, embed_dim,), device=device) + labels = torch.randint(n_classes, size=(bsize,), device=device) + + out = linear(inputs) + out_compiled = linear_compiled(inputs) + torch.testing.assert_close(out, out_compiled, atol=1e-2, rtol=1e-2) + + F.cross_entropy(out, labels).backward() + F.cross_entropy(out_compiled, labels).backward() + + for p, p_compiled in zip(linear.parameters(), linear_compiled.parameters()): + torch.testing.assert_close(p.grad, p_compiled.grad) + @parametrize("device", _DEVICES) def test_int8_linear_training(self, device): bsize = 4 diff --git a/torchao/prototype/quantized_training/subclass.py b/torchao/prototype/quantized_training/subclass.py index 06d1f3bc53..264fc78aab 100644 --- a/torchao/prototype/quantized_training/subclass.py +++ b/torchao/prototype/quantized_training/subclass.py @@ -113,7 +113,7 @@ def backward(ctx, grad_output): input, weight = ctx.saved_tensors dinput = (grad_output * weight.scale) @ weight.int_data.to(grad_output.dtype) - dweight = grad_output.flatten(0, -2).T @ input.flatten(0, -2) + dweight = grad_output.view(-1, weight.shape[0]).T @ input.view(-1, weight.shape[1]) dbias = grad_output.sum(0) if ctx.bias else None return dinput, dweight, dbias From f1587a274d298ce7c6b463e502601bdc4a950948 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 10 Aug 2024 14:22:08 +0800 Subject: [PATCH 17/41] fix bias --- torchao/prototype/quantized_training/subclass.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/prototype/quantized_training/subclass.py b/torchao/prototype/quantized_training/subclass.py index 264fc78aab..1e66168987 100644 --- a/torchao/prototype/quantized_training/subclass.py +++ b/torchao/prototype/quantized_training/subclass.py @@ -114,7 +114,7 @@ def backward(ctx, grad_output): dinput = (grad_output * weight.scale) @ weight.int_data.to(grad_output.dtype) dweight = grad_output.view(-1, weight.shape[0]).T @ input.view(-1, weight.shape[1]) - dbias = grad_output.sum(0) if ctx.bias else None + dbias = grad_output.view(-1, weight.shape[0]).sum(0) if ctx.bias else None return dinput, dweight, dbias From 7f9102a8279af39276a704e1544d045bf578c615 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 10 Aug 2024 16:18:50 +0800 Subject: [PATCH 18/41] substantial update to tests --- test/prototype/test_quantized_training.py | 129 +++++++++--------- .../prototype/quantized_training/subclass.py | 8 +- 2 files changed, 72 insertions(+), 65 deletions(-) diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py index cbd5f76832..237977bbb4 100644 --- a/test/prototype/test_quantized_training.py 
+++ b/test/prototype/test_quantized_training.py @@ -37,68 +37,71 @@ def test_int8_stochastic_rounding(self, device): @parametrize("leading_dims", [(), (2,), (2, 4)]) @parametrize("bias", [False, True]) @parametrize("device", _DEVICES) - def test_int8_linear_forward(self, leading_dims, bias, device): + def test_int8_linear(self, leading_dims, bias, device): embed_dim = 32 - linear_fp32 = nn.Linear(embed_dim, embed_dim * 2, bias=bias, device=device) + linear_fp32 = nn.Linear(embed_dim, embed_dim, bias=bias, device=device) linear_int8 = copy.deepcopy(linear_fp32) quantize_(linear_int8, int8_weight_only_quantized_training()) - assert isinstance(linear_int8.weight, Int8QTLinearWeight) + linear_fp32.weight.data = linear_int8.weight.data.dequantize() - inputs = torch.randn(leading_dims + (embed_dim,), device=device) - out_fp32 = linear_fp32(inputs) - out_int8 = linear_int8(inputs) - torch.testing.assert_close(out_fp32, out_int8, atol=1e-2, rtol=1e-2) + input_fp32 = torch.randn(leading_dims + (embed_dim,), device=device) + input_int8 = input_fp32.clone() + input_fp32.requires_grad_(True) + input_int8.requires_grad_(True) - @parametrize("device", _DEVICES) - def test_int8_linear_backward(self, device): - bsize = 4 - embed_dim = 32 - n_classes = 10 - - model_fp32 = nn.Sequential( - nn.Linear(embed_dim, embed_dim * 2, bias=False), - nn.GELU(), - nn.Linear(embed_dim * 2, n_classes), - ).to(device) - model_int8 = copy.deepcopy(model_fp32) - quantize_(model_int8, int8_weight_only_quantized_training()) + # quantize_() will set torch.set_float32_matmul_precision("high"), thus failing accuracy check on CUDA. + # manually override it here. + torch.set_float32_matmul_precision("highest") - inputs = torch.randn(bsize, embed_dim, device=device) - labels = torch.randint(n_classes, size=(bsize,), device=device) - F.cross_entropy(model_fp32(inputs), labels).backward() - F.cross_entropy(model_int8(inputs), labels).backward() + out_fp32 = linear_fp32(input_fp32) + out_int8 = linear_int8(input_int8) + torch.testing.assert_close(out_fp32, out_int8) - for p_fp32, p_int8 in zip(model_fp32.parameters(), model_int8.parameters()): - torch.testing.assert_close(p_fp32.grad, p_int8.grad, atol=1e-3, rtol=1e-2) + grad = torch.randn(leading_dims + (embed_dim,), device=device) + out_fp32.backward(grad) + out_int8.backward(grad) + torch.testing.assert_close(input_fp32.grad, input_int8.grad) + torch.testing.assert_close(linear_fp32.weight.grad, linear_int8.weight.grad) + if bias: + torch.testing.assert_close(linear_fp32.bias.grad, linear_int8.bias.grad) + @parametrize("leading_dims", [(), (2,), (2, 4)]) @parametrize("bias", [False, True]) @parametrize("device", _DEVICES) - def test_int8_linear_compile(self, bias, device): - bsize = 4 - embed_dim = 32 - n_classes = 10 + def test_int8_linear_compile(self, leading_dims, bias, device): + torch._dynamo.reset() + embed_dim = 128 - linear = nn.Linear(embed_dim, n_classes, bias=bias, device=device) - quantize_(linear, int8_weight_only_quantized_training()) - linear_compiled = copy.deepcopy(linear) + linear_eager = nn.Linear(embed_dim, embed_dim, bias=bias, device=device) + quantize_(linear_eager, int8_weight_only_quantized_training()) + linear_compiled = copy.deepcopy(linear_eager) linear_compiled.compile() - inputs = torch.randn((bsize, embed_dim,), device=device) - labels = torch.randint(n_classes, size=(bsize,), device=device) + input_eager = torch.randn(leading_dims + (embed_dim,), device=device) * 10 + input_compiled = input_eager.clone() + input_eager.requires_grad_(True) + 
input_compiled.requires_grad_(True) - out = linear(inputs) - out_compiled = linear_compiled(inputs) - torch.testing.assert_close(out, out_compiled, atol=1e-2, rtol=1e-2) + # quantize_() will set torch.set_float32_matmul_precision("high"), which causes segfault. + # manually override it here. + torch.set_float32_matmul_precision("highest") - F.cross_entropy(out, labels).backward() - F.cross_entropy(out_compiled, labels).backward() + out_eager = linear_eager(input_eager) + out_compiled = linear_compiled(input_compiled) + torch.testing.assert_close(out_eager, out_compiled) - for p, p_compiled in zip(linear.parameters(), linear_compiled.parameters()): - torch.testing.assert_close(p.grad, p_compiled.grad) + grad = torch.randn(leading_dims + (embed_dim,), device=device) + out_eager.backward(grad) + out_compiled.backward(grad) + torch.testing.assert_close(input_eager.grad, input_compiled.grad) + torch.testing.assert_close(linear_eager.weight.grad, linear_compiled.weight.grad) + if bias: + torch.testing.assert_close(linear_eager.bias.grad, linear_compiled.bias.grad) + @parametrize("compile", [False, True]) @parametrize("device", _DEVICES) - def test_int8_linear_training(self, device): + def test_int8_linear_training(self, compile, device): bsize = 4 embed_dim = 32 n_classes = 10 @@ -111,24 +114,33 @@ def test_int8_linear_training(self, device): model_int8 = copy.deepcopy(model_fp32) quantize_(model_int8, int8_weight_only_quantized_training()) + if compile: + model_fp32.compile() + model_int8.compile() + optim_fp32 = AdamW(model_fp32.parameters()) optim_int8 = AdamW(model_int8.parameters()) - for _ in range(2): + # prevent segfault with torch.compile() + torch.set_float32_matmul_precision("highest") + + for _ in range(5): inputs = torch.randn(bsize, embed_dim, device=device) labels = torch.randint(n_classes, size=(bsize,), device=device) - F.cross_entropy(model_fp32(inputs), labels).backward() - F.cross_entropy(model_int8(inputs), labels).backward() + loss_fp32 = F.cross_entropy(model_fp32(inputs), labels) + loss_int8 = F.cross_entropy(model_int8(inputs), labels) + + rel_error = abs(loss_int8.item() - loss_fp32.item()) / abs(loss_fp32.item()) + assert rel_error < 2e-3, rel_error + loss_fp32.backward() optim_fp32.step() optim_fp32.zero_grad() + + loss_int8.backward() optim_int8.step() optim_int8.zero_grad() - with torch.no_grad(): - for p_fp32, p_int8 in zip(model_fp32.parameters(), model_int8.parameters()): - torch.testing.assert_close(p_fp32, p_int8.dequantize(), atol=1e-2, rtol=1e-2) - class TestFSDP2(FSDPTest): @property @@ -138,19 +150,14 @@ def world_size(self) -> int: @skip_if_lt_x_gpu(2) def test_fsdp2(self): self.run_subtests( - { - "activation_checkpointing": [False, True], - # "compile_layer": [False, True], - }, + {"compile_layer": [False, True]}, self._test_fsdp2, ) - def _test_fsdp2(self, activation_checkpointing, compile_layer): + def _test_fsdp2(self, compile_layer): import torch.distributed as dist from torch.distributed._composable.fsdp import fully_shard - from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import apply_activation_checkpointing - from torch.distributed.fsdp.wrap import ModuleWrapPolicy - from torch.testing._internal.distributed._tensor.common_dtensor import ModelArgs, Transformer, TransformerBlock + from torch.testing._internal.distributed._tensor.common_dtensor import ModelArgs, Transformer batch_size = 3 vocab_size = 32 @@ -166,9 +173,6 @@ def _test_fsdp2(self, activation_checkpointing, compile_layer): torch.manual_seed(42) base_model = 
Transformer(model_args).cuda() quantize_(base_model, int8_weight_only_quantized_training()) - if activation_checkpointing: - policy = ModuleWrapPolicy({TransformerBlock}) - apply_activation_checkpointing(base_model, auto_wrap_policy=policy) fsdp_model = copy.deepcopy(base_model) if compile_layer: @@ -184,6 +188,9 @@ def _test_fsdp2(self, activation_checkpointing, compile_layer): base_optim = AdamW(base_model.parameters(), lr=1e-2) fsdp_optim = AdamW(fsdp_model.parameters(), lr=1e-2) + # prevent segfault with torch.compile() + torch.set_float32_matmul_precision("highest") + torch.manual_seed(42 + self.rank + 1) for iter_idx in range(5): inp = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda") diff --git a/torchao/prototype/quantized_training/subclass.py b/torchao/prototype/quantized_training/subclass.py index 1e66168987..43085f9630 100644 --- a/torchao/prototype/quantized_training/subclass.py +++ b/torchao/prototype/quantized_training/subclass.py @@ -112,10 +112,10 @@ def forward(ctx, input: Tensor, weight: Int8QTLinearWeight, bias: Optional[Tenso def backward(ctx, grad_output): input, weight = ctx.saved_tensors - dinput = (grad_output * weight.scale) @ weight.int_data.to(grad_output.dtype) - dweight = grad_output.view(-1, weight.shape[0]).T @ input.view(-1, weight.shape[1]) - dbias = grad_output.view(-1, weight.shape[0]).sum(0) if ctx.bias else None - return dinput, dweight, dbias + grad_input = (grad_output * weight.scale) @ weight.int_data.to(grad_output.dtype) + grad_weight = grad_output.view(-1, weight.shape[0]).T @ input.view(-1, weight.shape[1]) + grad_bias = grad_output.view(-1, weight.shape[0]).sum(0) if ctx.bias else None + return grad_input, grad_weight, grad_bias @Int8QTLinearWeight.implements(torch.nn.functional.linear) From 042833010b8c402d30a26878db1f0a5c07eaab32 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 10 Aug 2024 16:28:35 +0800 Subject: [PATCH 19/41] fix compile for FSDP --- test/prototype/test_quantized_training.py | 2 ++ torchao/prototype/quantized_training/subclass.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py index 237977bbb4..3e6022d206 100644 --- a/test/prototype/test_quantized_training.py +++ b/test/prototype/test_quantized_training.py @@ -102,6 +102,7 @@ def test_int8_linear_compile(self, leading_dims, bias, device): @parametrize("compile", [False, True]) @parametrize("device", _DEVICES) def test_int8_linear_training(self, compile, device): + torch._dynamo.reset() bsize = 4 embed_dim = 32 n_classes = 10 @@ -159,6 +160,7 @@ def _test_fsdp2(self, compile_layer): from torch.distributed._composable.fsdp import fully_shard from torch.testing._internal.distributed._tensor.common_dtensor import ModelArgs, Transformer + torch._dynamo.reset() batch_size = 3 vocab_size = 32 seq_len = 64 diff --git a/torchao/prototype/quantized_training/subclass.py b/torchao/prototype/quantized_training/subclass.py index 43085f9630..711f7887dc 100644 --- a/torchao/prototype/quantized_training/subclass.py +++ b/torchao/prototype/quantized_training/subclass.py @@ -21,6 +21,7 @@ class Int8QTLinearWeight(Tensor): __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__) @staticmethod + @torch._dynamo.disable def __new__(cls, int_data: Tensor, scale: Tensor, requires_grad: bool = False): return Tensor._make_wrapper_subclass( cls, @@ -30,6 +31,7 @@ def __new__(cls, int_data: Tensor, scale: Tensor, requires_grad: bool = False): requires_grad=requires_grad, ) + 
@torch._dynamo.disable def __init__(self, int_data: Tensor, scale: Tensor, requires_grad: bool = False): """Create a symmetric quantized INT8 weight. This tensor will appear to have the same dtype as `scale.dtype`. All in-place update ops will perform stochastic rounding. From 001422cd6cbe3c30a8086b5aa7f909a2f567996e Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 10 Aug 2024 17:51:54 +0800 Subject: [PATCH 20/41] update readme. rename file --- .../prototype/quantized_training/README.md | 23 +++++++++++++++++++ .../prototype/quantized_training/__init__.py | 2 +- .../{subclass.py => int8.py} | 0 3 files changed, 24 insertions(+), 1 deletion(-) rename torchao/prototype/quantized_training/{subclass.py => int8.py} (100%) diff --git a/torchao/prototype/quantized_training/README.md b/torchao/prototype/quantized_training/README.md index 43f832ffa6..6d2c465b64 100644 --- a/torchao/prototype/quantized_training/README.md +++ b/torchao/prototype/quantized_training/README.md @@ -5,3 +5,26 @@ This folder contains experimental work on quantized training (QT). The main diff - AQT: [[related paper](https://arxiv.org/abs/2105.03536)] [[code](https://github.com/google/aqt)] Currently we only support weight-only channel-wise INT8 symmetric quantization. + +## INT8 weight only + +In this recipe, all linear weights are quantized to INT8 using channel-wise symmetric quantization scheme `[-127, 127]`. During forward and backward, the weights are upcast to activations' dtype (e.g. BF16). Therefore, gradients are also in activations' dtype. In the optimizer step, we use **stochastic rounding** to update the weights, ensuring small weight updates can still change the weights. + +Usage + +```python +from torchao.prototype.quantized_training import int8_weight_only_quantized_training +from torchao.prototype.low_bit_optim import AdamW +from torchao.quantization.quant_api import quantize_ + +model = ... +quantize_(model, int8_weight_only_quantized_training()) + +optim = AdamW(model.parameters(), lr=3e-4) +``` + +It is recommended to use optimizers from `torchao.prototype.low_bit_optim` for quantized training, because they can automatically generate efficient fused optimizer kernel for `dequant->optimizer_step->quant` thanks to `torch.compile()`. + +[`benchmarks/benchmark_int8_qt.py`](../../../benchbenchmarks/benchmark_int8_qt.py) demonstrates an end-to-end Llama2 pre-training using this INT8 quantized training. + +See [#644](https://github.com/pytorch/ao/pull/644) for some early results. 
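To make the recipe above concrete, the snippet below is a minimal, self-contained sketch of channel-wise symmetric INT8 quantization and the weight-only linear it feeds, written in plain PyTorch rather than with the torchao tensor subclass. The helper names and toy shapes are illustrative only.

```python
import torch
import torch.nn.functional as F

def int8_symmetric_quantize(w: torch.Tensor):
    # per-output-channel absmax scale, mapping each row into [-127, 127]
    scale = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / 127
    int_data = (w / scale).round().clamp(-127, 127).to(torch.int8)
    return int_data, scale

def int8_weight_only_linear(x: torch.Tensor, int_data: torch.Tensor, scale: torch.Tensor):
    # upcast the INT8 weight to the activation dtype before the matmul, so the
    # output (and therefore the gradients) stay in the activation dtype
    w = int_data.to(x.dtype) * scale.to(x.dtype)
    return F.linear(x, w)

w_fp32 = torch.randn(64, 32)                   # (out_features, in_features)
int_data, scale = int8_symmetric_quantize(w_fp32)
x = torch.randn(4, 32, dtype=torch.bfloat16)   # BF16 activations
out = int8_weight_only_linear(x, int_data, scale)
print(out.shape, out.dtype)                    # torch.Size([4, 64]) torch.bfloat16
```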
diff --git a/torchao/prototype/quantized_training/__init__.py b/torchao/prototype/quantized_training/__init__.py index ad849e0fff..6c7f8eb9b1 100644 --- a/torchao/prototype/quantized_training/__init__.py +++ b/torchao/prototype/quantized_training/__init__.py @@ -1 +1 @@ -from .subclass import Int8QTLinearWeight, int8_weight_only_quantized_training +from .int8 import Int8QTLinearWeight, int8_weight_only_quantized_training diff --git a/torchao/prototype/quantized_training/subclass.py b/torchao/prototype/quantized_training/int8.py similarity index 100% rename from torchao/prototype/quantized_training/subclass.py rename to torchao/prototype/quantized_training/int8.py From 2eb2787ef5cc3c62541a12fa91d6e0cd4b2cbbf5 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 10 Aug 2024 19:09:13 +0800 Subject: [PATCH 21/41] speed up CI --- test/prototype/test_quantized_training.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py index 3e6022d206..7b03ae3170 100644 --- a/test/prototype/test_quantized_training.py +++ b/test/prototype/test_quantized_training.py @@ -165,9 +165,9 @@ def _test_fsdp2(self, compile_layer): vocab_size = 32 seq_len = 64 model_args = ModelArgs( - n_layers=3, - n_heads=4, - dim=1024, + n_layers=2, + n_heads=2, + dim=128, vocab_size=vocab_size, max_seq_len=seq_len, dropout_p=0, @@ -193,6 +193,10 @@ def _test_fsdp2(self, compile_layer): # prevent segfault with torch.compile() torch.set_float32_matmul_precision("highest") + # turn off these flags (set by quantize_()) to speed up compile time in CI + torch._inductor.config.coordinate_descent_tuning = False + torch._inductor.config.coordinate_descent_check_all_directions = False + torch.manual_seed(42 + self.rank + 1) for iter_idx in range(5): inp = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda") From d39caba90df91e9f6ede6eaeb30fc21612df4612 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 10 Aug 2024 19:11:49 +0800 Subject: [PATCH 22/41] fix typo --- torchao/prototype/quantized_training/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/prototype/quantized_training/README.md b/torchao/prototype/quantized_training/README.md index 6d2c465b64..b24c2532b8 100644 --- a/torchao/prototype/quantized_training/README.md +++ b/torchao/prototype/quantized_training/README.md @@ -25,6 +25,6 @@ optim = AdamW(model.parameters(), lr=3e-4) It is recommended to use optimizers from `torchao.prototype.low_bit_optim` for quantized training, because they can automatically generate efficient fused optimizer kernel for `dequant->optimizer_step->quant` thanks to `torch.compile()`. -[`benchmarks/benchmark_int8_qt.py`](../../../benchbenchmarks/benchmark_int8_qt.py) demonstrates an end-to-end Llama2 pre-training using this INT8 quantized training. +[`benchmarks/benchmark_int8_qt.py`](../../../benchmarks/benchmark_int8_qt.py) demonstrates an end-to-end Llama2 pre-training using this INT8 quantized training. See [#644](https://github.com/pytorch/ao/pull/644) for some early results. 
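As an illustration of the `dequant->optimizer_step->quant` flow mentioned above, the sketch below applies a plain SGD update to a row-wise INT8 weight and re-quantizes it with stochastic rounding. This is a toy demonstration, not the fused kernel that the low-bit optimizers generate via `torch.compile()`, and all helper names are made up for the example.

```python
import torch

@torch.no_grad()
def sgd_step_with_stochastic_rounding(int_data, scale, grad, lr):
    # 1. dequantize: reconstruct a floating-point view of the INT8 weight
    w = int_data.float() * scale
    # 2. optimizer step in floating point (plain SGD for simplicity)
    w = w - lr * grad
    # 3. re-quantize with stochastic rounding: the chance of rounding up equals
    #    the fractional part, so sub-LSB updates survive on average
    #    (the scale is kept fixed here to isolate the rounding behaviour)
    q = (w / scale + torch.rand_like(w)).floor()
    return q.clamp(-128, 127).to(torch.int8)

int_data = torch.randint(-100, 100, (8, 16)).to(torch.int8)
scale = torch.full((8, 1), 1e-2)
grad = torch.full((8, 16), 2e-2)

new_int_data = sgd_step_with_stochastic_rounding(int_data, scale, grad, lr=0.1)
# the update is only 0.2 of one INT8 step: nearest rounding would discard it
# entirely, while stochastic rounding changes roughly 20% of the entries
print((new_int_data != int_data).float().mean())
```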
From de6aa25aee4c8adb88d4bc1acf33887dece68a60 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 10 Aug 2024 19:12:51 +0800 Subject: [PATCH 23/41] fix typo --- benchmarks/benchmark_int8_qt.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/benchmarks/benchmark_int8_qt.py b/benchmarks/benchmark_int8_qt.py index a7595b118c..91bc2cc3c5 100644 --- a/benchmarks/benchmark_int8_qt.py +++ b/benchmarks/benchmark_int8_qt.py @@ -1,8 +1,8 @@ # pre-train a mini Llama2 on TinyStories with INT8 quantized training # pip install transformers sentencepiece wandb # -# BF16 baseline: python benchmarks/benchmark_int8_qt.py --seed 2024 --step 10_000 -# INT8 QT: python benchamrks/benchmark_int8_qt.py --seed 2024 --step 10_000 --quantize int8_weight_only +# BF16 baseline: python benchmarks/benchmark_int8_qt.py --seed 2024 --n_steps 10_000 +# INT8 QT: python benchamrks/benchmark_int8_qt.py --seed 2024 --n_steps 10_000 --quantize int8_weight_only import os @@ -108,7 +108,7 @@ def get_tinystories(): print(f"No. of params: {sum(p.numel() for p in model.parameters())}") print(f"No. of buffers: {sum(p.numel() for p in model.buffers())}") - optim = getattr(low_bit_optim, args.oprim)(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) + optim = getattr(low_bit_optim, args.optim)(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) data = get_tinystories().cuda() run = wandb.init(dir="/tmp", config=args, project=args.project, name=args.run_name) From adbe47df865072e3628cfcca27649b0407c3c46b Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 10 Aug 2024 19:33:51 +0800 Subject: [PATCH 24/41] typos. unset some dynamo flags --- benchmarks/benchmark_int8_qt.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/benchmarks/benchmark_int8_qt.py b/benchmarks/benchmark_int8_qt.py index 91bc2cc3c5..5d4298ee07 100644 --- a/benchmarks/benchmark_int8_qt.py +++ b/benchmarks/benchmark_int8_qt.py @@ -2,7 +2,7 @@ # pip install transformers sentencepiece wandb # # BF16 baseline: python benchmarks/benchmark_int8_qt.py --seed 2024 --n_steps 10_000 -# INT8 QT: python benchamrks/benchmark_int8_qt.py --seed 2024 --n_steps 10_000 --quantize int8_weight_only +# INT8 QT: python benchmarks/benchmark_int8_qt.py --seed 2024 --n_steps 10_000 --quantize int8_weight_only import os @@ -105,8 +105,12 @@ def get_tinystories(): quantize_(model, int8_weight_only_quantized_training()) elif args.quantize is not None: raise ValueError(f"Unsupported quantize={args.quantize}") - print(f"No. of params: {sum(p.numel() for p in model.parameters())}") - print(f"No. of buffers: {sum(p.numel() for p in model.buffers())}") + print(f"No. of params: {sum(p.numel() for p in model.parameters()):,}") + print(f"No. 
of buffers: {sum(p.numel() for p in model.buffers()):,}") + + # turn off these flags (set by quantize_()) to speed up compile time + torch._inductor.config.coordinate_descent_tuning = False + torch._inductor.config.coordinate_descent_check_all_directions = False optim = getattr(low_bit_optim, args.optim)(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) From 3fdf7766cc25d7ce2eef4009b477775dc4936716 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 10 Aug 2024 21:07:02 +0800 Subject: [PATCH 25/41] update readme --- torchao/prototype/quantized_training/README.md | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/torchao/prototype/quantized_training/README.md b/torchao/prototype/quantized_training/README.md index b24c2532b8..f60dd50f88 100644 --- a/torchao/prototype/quantized_training/README.md +++ b/torchao/prototype/quantized_training/README.md @@ -4,11 +4,17 @@ This folder contains experimental work on quantized training (QT). The main diff - Q-GaLore: [[paper](https://arxiv.org/abs/2407.08296)] [[code](https://github.com/VITA-Group/Q-GaLore)] - AQT: [[related paper](https://arxiv.org/abs/2105.03536)] [[code](https://github.com/google/aqt)] +Typically, low-precision weights cannot be trained directly due to quantization error: a small change in the quantized weight will be round down to zero. To tackle this problem, we use **stochastic rounding** for weight update. In simple terms, stochastic rounding will round up or down randomly, but with a higher chance if it is closer to that direction. For example, 0.8 will have 80% chance of rounding up and 20% of rounding down. It also follows that on average, stochastic rounding will estimate the floating point value exactly. + +There are 2 main benefits for training in this way: +1. Reduce memory footprint. Also reduce communication bandwidth in distributed setting. +2. What you train is what you serve (WYTIWYS). + Currently we only support weight-only channel-wise INT8 symmetric quantization. ## INT8 weight only -In this recipe, all linear weights are quantized to INT8 using channel-wise symmetric quantization scheme `[-127, 127]`. During forward and backward, the weights are upcast to activations' dtype (e.g. BF16). Therefore, gradients are also in activations' dtype. In the optimizer step, we use **stochastic rounding** to update the weights, ensuring small weight updates can still change the weights. +In this recipe, all linear weights are quantized to INT8 using channel-wise symmetric quantization `[-127, 127]`. In the forward and backward pass, the weights are upcast to activations' dtype (e.g. BF16). Therefore, their gradients are also in activations' dtype. Usage @@ -28,3 +34,9 @@ It is recommended to use optimizers from `torchao.prototype.low_bit_optim` for q [`benchmarks/benchmark_int8_qt.py`](../../../benchmarks/benchmark_int8_qt.py) demonstrates an end-to-end Llama2 pre-training using this INT8 quantized training. See [#644](https://github.com/pytorch/ao/pull/644) for some early results. + +## Future ideas + +- INT8 activation x INT8 weight. This can potentially leverage INT8 Tensor Cores, which is 2x faster than FP16/BF16 Tensor Cores. +- INT4 weight only (with group-wise quantization). This can be used with INT4 tinygemm deployment in mind (or other optimized INT4 kernels). +- FP8 activation x FP8 weight. The current FP8 training recipe can be seen as a form of QAT, which maintains a high-precision copy of model weights. 
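The statement above that stochastic rounding estimates the floating point value exactly on average is easy to verify empirically. The snippet below is only a quick sanity check in plain PyTorch, mirroring the intent of `test_int8_stochastic_rounding` in this PR, and is not part of the library code.

```python
import torch

# probability of rounding up equals the fractional part: E[floor(x + U[0,1))] = x
x = torch.full((100_000,), 0.8)
print((x + torch.rand_like(x)).floor().mean())   # ~0.8 (about 80% ones, 20% zeros)

# the same property makes the INT8 quantization unbiased on average: quantize the
# same row many times and the dequantized mean approaches the original input
w = torch.randn(1, 256)
scale = w.abs().amax(dim=-1, keepdim=True) / 127
acc = torch.zeros_like(w)
n_samples = 1_000
for _ in range(n_samples):
    q = (w / scale + torch.rand_like(w)).floor().clamp(-128, 127)
    acc += q * scale                             # dequantize and accumulate
print((acc / n_samples - w).abs().max())         # small, shrinks as n_samples grows
```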
From ea0ee4f38d8d09dc676cd822a25e469245db558e Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 11 Aug 2024 09:51:33 +0800 Subject: [PATCH 26/41] remove requires_grad, since it is unnecessary --- torchao/prototype/quantized_training/int8.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/torchao/prototype/quantized_training/int8.py b/torchao/prototype/quantized_training/int8.py index 711f7887dc..325b24daf7 100644 --- a/torchao/prototype/quantized_training/int8.py +++ b/torchao/prototype/quantized_training/int8.py @@ -22,17 +22,16 @@ class Int8QTLinearWeight(Tensor): @staticmethod @torch._dynamo.disable - def __new__(cls, int_data: Tensor, scale: Tensor, requires_grad: bool = False): + def __new__(cls, int_data: Tensor, scale: Tensor): return Tensor._make_wrapper_subclass( cls, int_data.shape, dtype=scale.dtype, device=int_data.device, - requires_grad=requires_grad, ) @torch._dynamo.disable - def __init__(self, int_data: Tensor, scale: Tensor, requires_grad: bool = False): + def __init__(self, int_data: Tensor, scale: Tensor): """Create a symmetric quantized INT8 weight. This tensor will appear to have the same dtype as `scale.dtype`. All in-place update ops will perform stochastic rounding. """ @@ -51,6 +50,7 @@ def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=No return cls(tensor_data_dict["int_data"], tensor_data_dict["scale"], *tensor_attributes) @staticmethod + @torch.no_grad() def quantize(tensor: Tensor, stochastic_rounding: bool = False): original_dtype = tensor.dtype tensor = tensor.float() @@ -71,9 +71,13 @@ def quantize(tensor: Tensor, stochastic_rounding: bool = False): @classmethod def from_float(cls, tensor: Tensor): - """Convert a float tensor into INT8 quantized weight. No stochastic rounding is performed.""" + """Convert a float tensor into INT8 quantized weight. No stochastic rounding is performed. + This function is not differentiable. 
+ """ int_data, scale = cls.quantize(tensor.detach()) - return cls(int_data, scale, requires_grad=tensor.requires_grad) + out = cls(int_data, scale) + out.requires_grad_(tensor.requires_grad) + return out def dequantize(self): return self.int_data * self.scale.view(-1, 1) From 36d0e1a3a5431fb30132a7e5aca914268cf4ded8 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 11 Aug 2024 09:53:18 +0800 Subject: [PATCH 27/41] remove note --- torchao/prototype/quantized_training/int8.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchao/prototype/quantized_training/int8.py b/torchao/prototype/quantized_training/int8.py index 325b24daf7..17c6a66bf1 100644 --- a/torchao/prototype/quantized_training/int8.py +++ b/torchao/prototype/quantized_training/int8.py @@ -60,7 +60,6 @@ def quantize(tensor: Tensor, stochastic_rounding: bool = False): tensor = tensor / scale.clip(1e-12).view(-1, 1) if stochastic_rounding: - # floor is required since .to(torch.int8) will convert 3.1 to 3 but -3.1 to -3 tensor = (tensor + torch.rand_like(tensor)).floor() else: tensor = tensor.round() From 6bc7621cf3b1b0131f5f7fa49e3c3e175faf3b59 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 13 Aug 2024 21:14:57 +0800 Subject: [PATCH 28/41] don't set inductor flags --- benchmarks/benchmark_int8_qt.py | 6 +--- test/prototype/test_quantized_training.py | 37 +++++++++-------------- 2 files changed, 16 insertions(+), 27 deletions(-) diff --git a/benchmarks/benchmark_int8_qt.py b/benchmarks/benchmark_int8_qt.py index 5d4298ee07..8c33461658 100644 --- a/benchmarks/benchmark_int8_qt.py +++ b/benchmarks/benchmark_int8_qt.py @@ -102,16 +102,12 @@ def get_tinystories(): if args.activation_checkpointing: model.gradient_checkpointing_enable() if args.quantize == "int8_weight_only": - quantize_(model, int8_weight_only_quantized_training()) + quantize_(model, int8_weight_only_quantized_training(), set_inductor_config=False) elif args.quantize is not None: raise ValueError(f"Unsupported quantize={args.quantize}") print(f"No. of params: {sum(p.numel() for p in model.parameters()):,}") print(f"No. of buffers: {sum(p.numel() for p in model.buffers()):,}") - # turn off these flags (set by quantize_()) to speed up compile time - torch._inductor.config.coordinate_descent_tuning = False - torch._inductor.config.coordinate_descent_check_all_directions = False - optim = getattr(low_bit_optim, args.optim)(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) data = get_tinystories().cuda() diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py index 7b03ae3170..a923463a97 100644 --- a/test/prototype/test_quantized_training.py +++ b/test/prototype/test_quantized_training.py @@ -19,6 +19,14 @@ _DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) +# using TF32 will cause mixed mm to segfault with triton backend +# fixed by https://github.com/pytorch/pytorch/pull/133173 but just set here to be safe +# also required for correctness check +torch.set_float32_matmul_precision("highest") + +# we always use `quantize_(set_inductor_config=False)` to reduce compile time in CI +# and make sure TF32 is not used (see above). 
+ class TestQuantizedTraining(TestCase): @parametrize("device", _DEVICES) @@ -42,7 +50,7 @@ def test_int8_linear(self, leading_dims, bias, device): linear_fp32 = nn.Linear(embed_dim, embed_dim, bias=bias, device=device) linear_int8 = copy.deepcopy(linear_fp32) - quantize_(linear_int8, int8_weight_only_quantized_training()) + quantize_(linear_int8, int8_weight_only_quantized_training(), set_inductor_config=False) linear_fp32.weight.data = linear_int8.weight.data.dequantize() input_fp32 = torch.randn(leading_dims + (embed_dim,), device=device) @@ -50,14 +58,12 @@ def test_int8_linear(self, leading_dims, bias, device): input_fp32.requires_grad_(True) input_int8.requires_grad_(True) - # quantize_() will set torch.set_float32_matmul_precision("high"), thus failing accuracy check on CUDA. - # manually override it here. - torch.set_float32_matmul_precision("highest") - + # test forward out_fp32 = linear_fp32(input_fp32) out_int8 = linear_int8(input_int8) torch.testing.assert_close(out_fp32, out_int8) + # test backward grad = torch.randn(leading_dims + (embed_dim,), device=device) out_fp32.backward(grad) out_int8.backward(grad) @@ -74,7 +80,7 @@ def test_int8_linear_compile(self, leading_dims, bias, device): embed_dim = 128 linear_eager = nn.Linear(embed_dim, embed_dim, bias=bias, device=device) - quantize_(linear_eager, int8_weight_only_quantized_training()) + quantize_(linear_eager, int8_weight_only_quantized_training(), set_inductor_config=False) linear_compiled = copy.deepcopy(linear_eager) linear_compiled.compile() @@ -83,10 +89,6 @@ def test_int8_linear_compile(self, leading_dims, bias, device): input_eager.requires_grad_(True) input_compiled.requires_grad_(True) - # quantize_() will set torch.set_float32_matmul_precision("high"), which causes segfault. - # manually override it here. 
- torch.set_float32_matmul_precision("highest") - out_eager = linear_eager(input_eager) out_compiled = linear_compiled(input_compiled) torch.testing.assert_close(out_eager, out_compiled) @@ -113,7 +115,8 @@ def test_int8_linear_training(self, compile, device): nn.Linear(embed_dim * 2, n_classes), ).to(device) model_int8 = copy.deepcopy(model_fp32) - quantize_(model_int8, int8_weight_only_quantized_training()) + # don't set inductor flags to speed up CI time + quantize_(model_int8, int8_weight_only_quantized_training(), set_inductor_config=False) if compile: model_fp32.compile() @@ -122,9 +125,6 @@ def test_int8_linear_training(self, compile, device): optim_fp32 = AdamW(model_fp32.parameters()) optim_int8 = AdamW(model_int8.parameters()) - # prevent segfault with torch.compile() - torch.set_float32_matmul_precision("highest") - for _ in range(5): inputs = torch.randn(bsize, embed_dim, device=device) labels = torch.randint(n_classes, size=(bsize,), device=device) @@ -174,7 +174,7 @@ def _test_fsdp2(self, compile_layer): ) torch.manual_seed(42) base_model = Transformer(model_args).cuda() - quantize_(base_model, int8_weight_only_quantized_training()) + quantize_(base_model, int8_weight_only_quantized_training(), set_inductor_config=False) fsdp_model = copy.deepcopy(base_model) if compile_layer: @@ -190,13 +190,6 @@ def _test_fsdp2(self, compile_layer): base_optim = AdamW(base_model.parameters(), lr=1e-2) fsdp_optim = AdamW(fsdp_model.parameters(), lr=1e-2) - # prevent segfault with torch.compile() - torch.set_float32_matmul_precision("highest") - - # turn off these flags (set by quantize_()) to speed up compile time in CI - torch._inductor.config.coordinate_descent_tuning = False - torch._inductor.config.coordinate_descent_check_all_directions = False - torch.manual_seed(42 + self.rank + 1) for iter_idx in range(5): inp = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda") From 6646c0b57530b59c399fb35684d3058848e4f08d Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 13 Aug 2024 21:36:11 +0800 Subject: [PATCH 29/41] rename --- .../pretrain_llama2.py} | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) rename benchmarks/{benchmark_int8_qt.py => quantized_training/pretrain_llama2.py} (95%) diff --git a/benchmarks/benchmark_int8_qt.py b/benchmarks/quantized_training/pretrain_llama2.py similarity index 95% rename from benchmarks/benchmark_int8_qt.py rename to benchmarks/quantized_training/pretrain_llama2.py index 8c33461658..aa3cf90b86 100644 --- a/benchmarks/benchmark_int8_qt.py +++ b/benchmarks/quantized_training/pretrain_llama2.py @@ -1,8 +1,8 @@ # pre-train a mini Llama2 on TinyStories with INT8 quantized training # pip install transformers sentencepiece wandb # -# BF16 baseline: python benchmarks/benchmark_int8_qt.py --seed 2024 --n_steps 10_000 -# INT8 QT: python benchmarks/benchmark_int8_qt.py --seed 2024 --n_steps 10_000 --quantize int8_weight_only +# BF16 baseline: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 +# INT8 QT: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --quantize int8_weight_only import os From 00e25cf499dc93a46794f4573c4eb89bc72cae75 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 13 Aug 2024 21:37:11 +0800 Subject: [PATCH 30/41] update README --- torchao/prototype/quantized_training/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/prototype/quantized_training/README.md b/torchao/prototype/quantized_training/README.md index 
f60dd50f88..7d1820a9e4 100644 --- a/torchao/prototype/quantized_training/README.md +++ b/torchao/prototype/quantized_training/README.md @@ -8,7 +8,7 @@ Typically, low-precision weights cannot be trained directly due to quantization There are 2 main benefits for training in this way: 1. Reduce memory footprint. Also reduce communication bandwidth in distributed setting. -2. What you train is what you serve (WYTIWYS). +2. What you train is what you serve ([WYTIWYS](https://github.com/google/aqt?tab=readme-ov-file#features)). Currently we only support weight-only channel-wise INT8 symmetric quantization. From 927a6d180cfd3162e722ff92a3c57785be484e1d Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 13 Aug 2024 21:51:56 +0800 Subject: [PATCH 31/41] rename optimizer --- .../quantized_training/pretrain_llama2.py | 3 +++ test/prototype/test_low_bit_optim.py | 24 ------------------- test/prototype/test_quantized_training.py | 10 ++++---- torchao/prototype/low_bit_optim/__init__.py | 2 +- torchao/prototype/low_bit_optim/adam.py | 8 +++---- torchao/prototype/low_bit_optim/adamw.py | 13 +++++----- 6 files changed, 20 insertions(+), 40 deletions(-) diff --git a/benchmarks/quantized_training/pretrain_llama2.py b/benchmarks/quantized_training/pretrain_llama2.py index aa3cf90b86..9ce8503931 100644 --- a/benchmarks/quantized_training/pretrain_llama2.py +++ b/benchmarks/quantized_training/pretrain_llama2.py @@ -108,6 +108,9 @@ def get_tinystories(): print(f"No. of params: {sum(p.numel() for p in model.parameters()):,}") print(f"No. of buffers: {sum(p.numel() for p in model.buffers()):,}") + # only use optimizers from torchao.prototype.low_bit_optim to support quantized training + if args.optim == "AdamW": + args.optim = "_AdamW" optim = getattr(low_bit_optim, args.optim)(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) data = get_tinystories().cuda() diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py index 6bc4ce8110..28dd377408 100644 --- a/test/prototype/test_low_bit_optim.py +++ b/test/prototype/test_low_bit_optim.py @@ -98,30 +98,6 @@ def test_optim_smoke(self, optim_name, dtype, device): optim.step() optim.zero_grad() - @parametrize("device", _DEVICES) - def test_optim_standard_correctness(self, device): - model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device) - model2 = copy.deepcopy(model1) - - optim1 = torch.optim.AdamW(model1.parameters()) - optim2 = low_bit_optim.AdamW(model2.parameters()) - - for _ in range(2): - x = torch.randn(4, 32, device=device) - - loss1 = model1(x).sum() - loss1.backward() - optim1.step() - optim1.zero_grad() - - loss2 = model2(x).sum() - loss2.backward() - optim2.step() - optim2.zero_grad() - - for p1, p2 in zip(model1.parameters(), model2.parameters()): - torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5) - @pytest.mark.skipif(bnb is None, reason="bitsandbytes is not availablle") @pytest.mark.skipif(not torch.cuda.is_available(), reason="bitsandbytes 8-bit Adam only works for CUDA") @pytest.mark.xfail(not TORCH_VERSION_AFTER_2_3, reason="torch.compile() fails for PyTorch < 2.3") diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py index a923463a97..9c8904d9ec 100644 --- a/test/prototype/test_quantized_training.py +++ b/test/prototype/test_quantized_training.py @@ -8,7 +8,7 @@ from torch.testing._internal.common_fsdp import FSDPTest from torch.testing._internal.common_utils import TestCase, 
instantiate_parametrized_tests, parametrize, run_tests -from torchao.prototype.low_bit_optim import AdamW +from torchao.prototype.low_bit_optim import _AdamW from torchao.prototype.quantized_training import Int8QTLinearWeight, int8_weight_only_quantized_training from torchao.quantization.quant_api import quantize_ from torchao.utils import TORCH_VERSION_AFTER_2_3 @@ -122,8 +122,8 @@ def test_int8_linear_training(self, compile, device): model_fp32.compile() model_int8.compile() - optim_fp32 = AdamW(model_fp32.parameters()) - optim_int8 = AdamW(model_int8.parameters()) + optim_fp32 = _AdamW(model_fp32.parameters()) + optim_int8 = _AdamW(model_int8.parameters()) for _ in range(5): inputs = torch.randn(bsize, embed_dim, device=device) @@ -187,8 +187,8 @@ def _test_fsdp2(self, compile_layer): fully_shard(layer) fully_shard(fsdp_model) - base_optim = AdamW(base_model.parameters(), lr=1e-2) - fsdp_optim = AdamW(fsdp_model.parameters(), lr=1e-2) + base_optim = _AdamW(base_model.parameters(), lr=1e-2) + fsdp_optim = _AdamW(fsdp_model.parameters(), lr=1e-2) torch.manual_seed(42 + self.rank + 1) for iter_idx in range(5): diff --git a/torchao/prototype/low_bit_optim/__init__.py b/torchao/prototype/low_bit_optim/__init__.py index c351f7b48b..5e9cc50c67 100644 --- a/torchao/prototype/low_bit_optim/__init__.py +++ b/torchao/prototype/low_bit_optim/__init__.py @@ -1,3 +1,3 @@ from .adam import Adam8bit, Adam4bit, AdamFp8 -from .adamw import AdamW, AdamW8bit, AdamW4bit, AdamWFp8 +from .adamw import _AdamW, AdamW8bit, AdamW4bit, AdamWFp8 from .cpu_offload import CPUOffloadOptimizer diff --git a/torchao/prototype/low_bit_optim/adam.py b/torchao/prototype/low_bit_optim/adam.py index 47a99c06dc..a5425e9840 100644 --- a/torchao/prototype/low_bit_optim/adam.py +++ b/torchao/prototype/low_bit_optim/adam.py @@ -10,7 +10,7 @@ from .subclass_fp8 import OptimStateFp8 -class _Adam(Optimizer): +class _AdamBase(Optimizer): def __init__(self, params, lr, betas, eps, weight_decay, amsgrad, *, block_size) -> None: if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) @@ -155,7 +155,7 @@ def single_param_adam( p.addcdiv_(new_exp_avg, denom, value=-step_size) -class Adam8bit(_Adam): +class Adam8bit(_AdamBase): def __init__( self, params, @@ -174,7 +174,7 @@ def _subclass_zeros(p: Tensor, signed: bool, block_size: int): return OptimState8bit.zeros(p.shape, signed, block_size, p.device) -class Adam4bit(_Adam): +class Adam4bit(_AdamBase): def __init__( self, params, @@ -233,7 +233,7 @@ def step(self, closure=None): return loss -class AdamFp8(_Adam): +class AdamFp8(_AdamBase): def __init__( self, params, diff --git a/torchao/prototype/low_bit_optim/adamw.py b/torchao/prototype/low_bit_optim/adamw.py index 456629fde7..9d1df8e6c8 100644 --- a/torchao/prototype/low_bit_optim/adamw.py +++ b/torchao/prototype/low_bit_optim/adamw.py @@ -10,7 +10,7 @@ from .subclass_fp8 import OptimStateFp8 -class _AdamW(Optimizer): +class _AdamWBase(Optimizer): def __init__(self, params, lr, betas, eps, weight_decay, amsgrad, *, block_size) -> None: if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) @@ -153,7 +153,7 @@ def single_param_adamw( p.add_(-lr * weight_decay * p - step_size * new_exp_avg / denom) -class AdamW(_AdamW): +class _AdamW(_AdamWBase): def __init__( self, params, @@ -163,11 +163,12 @@ def __init__( weight_decay=1e-2, amsgrad=False, ) -> None: - """AdamW optimizer that supports quantized training (parameter is quantized).""" + """AdamW optimizer that supports quantized training (parameter 
is quantized). This optimizer should + only be used with torchao's quantized training.""" super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=float("inf")) -class AdamW8bit(_AdamW): +class AdamW8bit(_AdamWBase): def __init__( self, params, @@ -186,7 +187,7 @@ def _subclass_zeros(p: Tensor, signed: bool, block_size: int): return OptimState8bit.zeros(p.shape, signed, block_size, p.device) -class AdamW4bit(_AdamW): +class AdamW4bit(_AdamWBase): def __init__( self, params, @@ -245,7 +246,7 @@ def step(self, closure=None): return loss -class AdamWFp8(_AdamW): +class AdamWFp8(_AdamWBase): def __init__( self, params, From de49e8be60413ac392bef4c8e6650f3289c79fb2 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 14 Aug 2024 09:20:42 +0800 Subject: [PATCH 32/41] update benchmark script --- benchmarks/quantized_training/pretrain_llama2.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/benchmarks/quantized_training/pretrain_llama2.py b/benchmarks/quantized_training/pretrain_llama2.py index 9ce8503931..431e2e18df 100644 --- a/benchmarks/quantized_training/pretrain_llama2.py +++ b/benchmarks/quantized_training/pretrain_llama2.py @@ -1,8 +1,8 @@ # pre-train a mini Llama2 on TinyStories with INT8 quantized training # pip install transformers sentencepiece wandb # -# BF16 baseline: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 -# INT8 QT: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --quantize int8_weight_only +# BF16 baseline: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile +# INT8 QT: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile --quantize int8_weight_only import os @@ -73,6 +73,7 @@ def get_tinystories(): parser.add_argument("--quantize") parser.add_argument("--activation_checkpointing", action="store_true") + parser.add_argument("--compile", action="store_true") parser.add_argument("--n_steps", type=int, default=1000) parser.add_argument("--batch_size", type=int, default=4) @@ -133,7 +134,8 @@ def get_tinystories(): log_dict = dict( loss=loss.item(), lr=optim.param_groups[0]["lr"], - max_memory_allocated=torch.cuda.max_memory_allocated(), + max_memory_allocated=torch.cuda.max_memory_allocated() / 1e9, + max_memory_active=torch.cuda.memory_stats().get("active_bytes.all.peak", 0) / 1e9, ) run.log(log_dict, step=step) pbar.set_postfix(loss=log_dict["loss"]) From f80ac97d5d7668343694d19333be72ac1ffc5cd2 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 14 Aug 2024 14:08:55 +0800 Subject: [PATCH 33/41] make compile explicit --- benchmarks/quantized_training/pretrain_llama2.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/benchmarks/quantized_training/pretrain_llama2.py b/benchmarks/quantized_training/pretrain_llama2.py index 431e2e18df..344a3a71af 100644 --- a/benchmarks/quantized_training/pretrain_llama2.py +++ b/benchmarks/quantized_training/pretrain_llama2.py @@ -121,13 +121,14 @@ def get_tinystories(): log_interval = 50 pbar = tqdm(total=args.n_steps, dynamic_ncols=True) model.train() + _get_loss = torch.compile(get_loss) if args.compile else get_loss while step < args.n_steps: # randomly select a continuous chunk, then reshape it idx = torch.randint(0, data.shape[0] - args.batch_size * args.seq_len, (1,)).item() batch = data[idx : idx + args.batch_size * args.seq_len].view(args.batch_size, args.seq_len).long() - loss = torch.compile(get_loss)(model, 
batch) + loss = _get_loss(model, batch) loss.backward() if step % log_interval == 0: From e375c3d4530b635f70d9a1b6d9e28b9d654d3be9 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 14 Aug 2024 14:35:03 +0800 Subject: [PATCH 34/41] update docs --- torchao/prototype/quantized_training/README.md | 12 +++++++----- torchao/prototype/quantized_training/int8.py | 16 ++++++++++++---- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/torchao/prototype/quantized_training/README.md b/torchao/prototype/quantized_training/README.md index 7d1820a9e4..fb74255a3f 100644 --- a/torchao/prototype/quantized_training/README.md +++ b/torchao/prototype/quantized_training/README.md @@ -6,6 +6,8 @@ This folder contains experimental work on quantized training (QT). The main diff Typically, low-precision weights cannot be trained directly due to quantization error: a small change in the quantized weight will be round down to zero. To tackle this problem, we use **stochastic rounding** for weight update. In simple terms, stochastic rounding will round up or down randomly, but with a higher chance if it is closer to that direction. For example, 0.8 will have 80% chance of rounding up and 20% of rounding down. It also follows that on average, stochastic rounding will estimate the floating point value exactly. +In precise terms, the probability of rounding up is `x - ⌊x⌋`. Note that when the value is exactly an integer value, the probability of rounding up is zero. + There are 2 main benefits for training in this way: 1. Reduce memory footprint. Also reduce communication bandwidth in distributed setting. 2. What you train is what you serve ([WYTIWYS](https://github.com/google/aqt?tab=readme-ov-file#features)). @@ -20,18 +22,18 @@ Usage ```python from torchao.prototype.quantized_training import int8_weight_only_quantized_training -from torchao.prototype.low_bit_optim import AdamW +from torchao.prototype.low_bit_optim import _AdamW from torchao.quantization.quant_api import quantize_ model = ... quantize_(model, int8_weight_only_quantized_training()) -optim = AdamW(model.parameters(), lr=3e-4) +optim = _AdamW(model.parameters(), lr=3e-4) ``` -It is recommended to use optimizers from `torchao.prototype.low_bit_optim` for quantized training, because they can automatically generate efficient fused optimizer kernel for `dequant->optimizer_step->quant` thanks to `torch.compile()`. +Only `torch.optim.Adam` and optimizers from `torchao.prototype.low_bit_optim` are known to work with quantized training in this folder. This is because we implement stochastic rounding logic within tensor subclass instead of the optimizer. We provide `torchao.prototype.low_bit_optim._AdamW` as an alternative to `torch.optim.AdamW` specifically for this purpose. -[`benchmarks/benchmark_int8_qt.py`](../../../benchmarks/benchmark_int8_qt.py) demonstrates an end-to-end Llama2 pre-training using this INT8 quantized training. +[`benchmarks/quantized_training/pretrain_llama2.py`](../../../benchmarks/quantized_training/pretrain_llama2.py) demonstrates an end-to-end Llama2 pre-training using this INT8 quantized training. See [#644](https://github.com/pytorch/ao/pull/644) for some early results. @@ -39,4 +41,4 @@ See [#644](https://github.com/pytorch/ao/pull/644) for some early results. - INT8 activation x INT8 weight. This can potentially leverage INT8 Tensor Cores, which is 2x faster than FP16/BF16 Tensor Cores. - INT4 weight only (with group-wise quantization). 
This can be used with INT4 tinygemm deployment in mind (or other optimized INT4 kernels). -- FP8 activation x FP8 weight. The current FP8 training recipe can be seen as a form of QAT, which maintains a high-precision copy of model weights. +- FP8 activation x FP8 weight. The current FP8 training recipe can be seen as a form of QAT, which maintains a high-precision copy of model weights. We can eliminate the high-precision copy. diff --git a/torchao/prototype/quantized_training/int8.py b/torchao/prototype/quantized_training/int8.py index 17c6a66bf1..085199ef12 100644 --- a/torchao/prototype/quantized_training/int8.py +++ b/torchao/prototype/quantized_training/int8.py @@ -12,10 +12,14 @@ _c10d_functional = torch.ops._c10d_functional -# the main difference of this tensor subclass from AffineQuantizedTensor: -# 1. F.linear is differentiable i.e. backward is defined. -# 2. support stochastic rounding when casting from floating point. class Int8QTLinearWeight(Tensor): + """INT8 symmetric quantization weight, with absmax scaling [-127, 127]. The main difference + of this tensor subclass from AffineQuantizedTensor: + 1. `F.linear` is differentiable i.e. backward is defined. + 2. All in-place ops, such as `aten.copy_`, will perform stochastic rounding. + `Int8QTLinearWeight.from_float()` does not perform stochastic rounding. + """ + implements = classmethod(_implements) __torch_function__ = classmethod(_dispatch__torch_function__) __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__) @@ -52,6 +56,11 @@ def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=No @staticmethod @torch.no_grad() def quantize(tensor: Tensor, stochastic_rounding: bool = False): + """Normal rounding will always round down small changes in weight update. To tackle this problem, + stochastic rounding can be used, which has a low chance, but not zero, of rounding up. The + probability of rounding up is equal to x - ⌊x⌋, which indicates how close the value is to the next + integer value. Thus, stochastic rounding also approximates the floating point value exactly. + """ original_dtype = tensor.dtype tensor = tensor.float() @@ -64,7 +73,6 @@ def quantize(tensor: Tensor, stochastic_rounding: bool = False): else: tensor = tensor.round() - # NOTE: is clipping necessary? 
tensor = tensor.clip(-128, 127).to(torch.int8) return tensor, scale.to(original_dtype) From 662c61f64825779ba14680876474a295c6bb805c Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 16 Aug 2024 01:56:38 +0000 Subject: [PATCH 35/41] use torch.optim.Adam to avoid FSDP optim compile bug --- test/prototype/test_quantized_training.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py index 9c8904d9ec..4e4eb57ac8 100644 --- a/test/prototype/test_quantized_training.py +++ b/test/prototype/test_quantized_training.py @@ -187,8 +187,8 @@ def _test_fsdp2(self, compile_layer): fully_shard(layer) fully_shard(fsdp_model) - base_optim = _AdamW(base_model.parameters(), lr=1e-2) - fsdp_optim = _AdamW(fsdp_model.parameters(), lr=1e-2) + base_optim = torch.optim.Adam(base_model.parameters(), lr=1e-2) + fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-2) torch.manual_seed(42 + self.rank + 1) for iter_idx in range(5): From cc90298c6a6743eafd371fc7cdd8719e630b67d3 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 16 Aug 2024 02:03:44 +0000 Subject: [PATCH 36/41] update docs --- torchao/prototype/quantized_training/README.md | 9 +++++++++ torchao/prototype/quantized_training/int8.py | 7 +++++++ 2 files changed, 16 insertions(+) diff --git a/torchao/prototype/quantized_training/README.md b/torchao/prototype/quantized_training/README.md index fb74255a3f..9b2980aa2b 100644 --- a/torchao/prototype/quantized_training/README.md +++ b/torchao/prototype/quantized_training/README.md @@ -37,6 +37,15 @@ Only `torch.optim.Adam` and optimizers from `torchao.prototype.low_bit_optim` ar See [#644](https://github.com/pytorch/ao/pull/644) for some early results. +TODO: investigate suboptimal memory saving when `torch.compile()` is used. Might be due to transposed weight. Memory benchamark for Llama2-1B, bs=4, seq_len=2048, activation checkpointing. + +Model | Peak memory (GB) +----------------|----------------- +BF16 eager | 11.06847 +BF16 compile | 10.16915 +INT8 QT eager | 10.11437 +INT8 QT compile | 10.03365 + ## Future ideas - INT8 activation x INT8 weight. This can potentially leverage INT8 Tensor Cores, which is 2x faster than FP16/BF16 Tensor Cores. diff --git a/torchao/prototype/quantized_training/int8.py b/torchao/prototype/quantized_training/int8.py index 085199ef12..aa1d7344bc 100644 --- a/torchao/prototype/quantized_training/int8.py +++ b/torchao/prototype/quantized_training/int8.py @@ -60,6 +60,13 @@ def quantize(tensor: Tensor, stochastic_rounding: bool = False): stochastic rounding can be used, which has a low chance, but not zero, of rounding up. The probability of rounding up is equal to x - ⌊x⌋, which indicates how close the value is to the next integer value. Thus, stochastic rounding also approximates the floating point value exactly. + + Currently this function differs from AQT's `int8_weight_only()` in the following way: + 1. Precision: AQT keeps original dtype when doing quantization, while this function upcasts input + to FP32 before quantization, and downcast scale to original dtype. + 2. Calculate scale: AQT uses `input.abs().amax() / 127.5`, while `input.abs().amax() / 127` is + done here. + 3. Apply scale: AQT uses `input * (1 / scale)`, while this function performs `input / scale`. 
""" original_dtype = tensor.dtype tensor = tensor.float() From f1c588b7fabce50addf4361c5ee089c08c08f8d6 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 16 Aug 2024 02:05:59 +0000 Subject: [PATCH 37/41] update doc --- torchao/prototype/quantized_training/int8.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torchao/prototype/quantized_training/int8.py b/torchao/prototype/quantized_training/int8.py index aa1d7344bc..a40e209519 100644 --- a/torchao/prototype/quantized_training/int8.py +++ b/torchao/prototype/quantized_training/int8.py @@ -252,6 +252,9 @@ def _(func, types, args, kwargs): def int8_weight_only_quantized_training(): + # TODO: right now `_get_linear_subclass_inserter()` will always set `requires_grad=False` + # when we have this out of prototype (or there are stable trainable tensor subclasses), + # update `_get_linear_subclass_inserter()` to allow `requires_grad=True`. def apply_int8_linear_weight(linear: nn.Linear): linear.weight = nn.Parameter( Int8QTLinearWeight.from_float(linear.weight), From f444fa6580c3baccd2d6afb7dd35b54fc6ecf654 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 16 Aug 2024 02:13:17 +0000 Subject: [PATCH 38/41] update docs --- torchao/prototype/quantized_training/int8.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/torchao/prototype/quantized_training/int8.py b/torchao/prototype/quantized_training/int8.py index a40e209519..c301f011c2 100644 --- a/torchao/prototype/quantized_training/int8.py +++ b/torchao/prototype/quantized_training/int8.py @@ -18,6 +18,8 @@ class Int8QTLinearWeight(Tensor): 1. `F.linear` is differentiable i.e. backward is defined. 2. All in-place ops, such as `aten.copy_`, will perform stochastic rounding. `Int8QTLinearWeight.from_float()` does not perform stochastic rounding. + 3. The numerics for quantization is slightly different. See `Int8QTLinearWeight.quantize()` + for more details. """ implements = classmethod(_implements) @@ -244,7 +246,11 @@ def _(func, types, args, kwargs): return Int8QTLinearWeight(int_data, scale) -# don't do anything. workaround for FSDP2. might give unexpected or wrong results. +# FSDP2 will call these two ops, expecting a view, not a copy. It doesn't make sense to +# correctly support these ops. For example, `.scale` depends on the shape of the weight, +# since this is channel-wise quantization. +# Thus, this is a workaround for FSDP2. Users SHOULD NOT call these ops directly, since +# they will produce unexpected or wrong results. 
@Int8QTLinearWeight.implements([aten.view.default, aten.as_strided.default]) def _(func, types, args, kwargs): out = Int8QTLinearWeight(args[0].int_data, args[0].scale) From 640ec2dd377aae30da29b7dde769315266fc39c2 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 16 Aug 2024 03:03:57 +0000 Subject: [PATCH 39/41] fix CI test --- test/prototype/test_quantized_training.py | 24 ++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py index 4e4eb57ac8..714cc7d4cd 100644 --- a/test/prototype/test_quantized_training.py +++ b/test/prototype/test_quantized_training.py @@ -19,15 +19,16 @@ _DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) -# using TF32 will cause mixed mm to segfault with triton backend -# fixed by https://github.com/pytorch/pytorch/pull/133173 but just set here to be safe -# also required for correctness check -torch.set_float32_matmul_precision("highest") -# we always use `quantize_(set_inductor_config=False)` to reduce compile time in CI -# and make sure TF32 is not used (see above). +def _reset(): + # using TF32 will cause mixed mm to segfault with triton backend + # fixed in nightly by https://github.com/pytorch/pytorch/pull/133173 + # also required for correctness check + torch.set_float32_matmul_precision("highest") + torch._dynamo.reset() +# we always use `quantize_(set_inductor_config=False)` to reduce compile time in CI. class TestQuantizedTraining(TestCase): @parametrize("device", _DEVICES) def test_int8_stochastic_rounding(self, device): @@ -46,6 +47,7 @@ def test_int8_stochastic_rounding(self, device): @parametrize("bias", [False, True]) @parametrize("device", _DEVICES) def test_int8_linear(self, leading_dims, bias, device): + _reset() embed_dim = 32 linear_fp32 = nn.Linear(embed_dim, embed_dim, bias=bias, device=device) @@ -76,7 +78,7 @@ def test_int8_linear(self, leading_dims, bias, device): @parametrize("bias", [False, True]) @parametrize("device", _DEVICES) def test_int8_linear_compile(self, leading_dims, bias, device): - torch._dynamo.reset() + _reset() embed_dim = 128 linear_eager = nn.Linear(embed_dim, embed_dim, bias=bias, device=device) @@ -104,7 +106,7 @@ def test_int8_linear_compile(self, leading_dims, bias, device): @parametrize("compile", [False, True]) @parametrize("device", _DEVICES) def test_int8_linear_training(self, compile, device): - torch._dynamo.reset() + _reset() bsize = 4 embed_dim = 32 n_classes = 10 @@ -160,7 +162,7 @@ def _test_fsdp2(self, compile_layer): from torch.distributed._composable.fsdp import fully_shard from torch.testing._internal.distributed._tensor.common_dtensor import ModelArgs, Transformer - torch._dynamo.reset() + _reset() batch_size = 3 vocab_size = 32 seq_len = 64 @@ -187,8 +189,8 @@ def _test_fsdp2(self, compile_layer): fully_shard(layer) fully_shard(fsdp_model) - base_optim = torch.optim.Adam(base_model.parameters(), lr=1e-2) - fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-2) + base_optim = torch.optim.Adam(base_model.parameters(), lr=1e-2, foreach=False, fused=False) + fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-2, foreach=False, fused=False) torch.manual_seed(42 + self.rank + 1) for iter_idx in range(5): From dad6560f09cf6baaac0eeb50293edac60a5177e0 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 16 Aug 2024 03:38:13 +0000 Subject: [PATCH 40/41] skip test --- test/prototype/test_quantized_training.py | 5 ++++- 1 
file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py index 714cc7d4cd..4c2bc99bc2 100644 --- a/test/prototype/test_quantized_training.py +++ b/test/prototype/test_quantized_training.py @@ -11,7 +11,7 @@ from torchao.prototype.low_bit_optim import _AdamW from torchao.prototype.quantized_training import Int8QTLinearWeight, int8_weight_only_quantized_training from torchao.quantization.quant_api import quantize_ -from torchao.utils import TORCH_VERSION_AFTER_2_3 +from torchao.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4 if not TORCH_VERSION_AFTER_2_3: pytest.skip("Requires torch>=2.4", allow_module_level=True) @@ -158,6 +158,9 @@ def test_fsdp2(self): ) def _test_fsdp2(self, compile_layer): + if compile_layer and not TORCH_VERSION_AFTER_2_4: + pytest.skip("FSDP2 + compiled quantized training fails with PyTorch 2.4") + import torch.distributed as dist from torch.distributed._composable.fsdp import fully_shard from torch.testing._internal.distributed._tensor.common_dtensor import ModelArgs, Transformer From 4924e8d491bca99570073cdae36a673b353b58c8 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 16 Aug 2024 04:54:16 +0000 Subject: [PATCH 41/41] fix compiled test --- test/prototype/test_quantized_training.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py index 4c2bc99bc2..6b4b6a6be9 100644 --- a/test/prototype/test_quantized_training.py +++ b/test/prototype/test_quantized_training.py @@ -152,15 +152,17 @@ def world_size(self) -> int: @skip_if_lt_x_gpu(2) def test_fsdp2(self): + # FSDP2 + compiled quantized training fails with PyTorch 2.4 + compile_layer_choices = [False] + if TORCH_VERSION_AFTER_2_4: + compile_layer_choices.append(True) + self.run_subtests( - {"compile_layer": [False, True]}, + {"compile_layer": compile_layer_choices}, self._test_fsdp2, ) def _test_fsdp2(self, compile_layer): - if compile_layer and not TORCH_VERSION_AFTER_2_4: - pytest.skip("FSDP2 + compiled quantized training fails with PyTorch 2.4") - import torch.distributed as dist from torch.distributed._composable.fsdp import fully_shard from torch.testing._internal.distributed._tensor.common_dtensor import ModelArgs, Transformer