diff --git a/benchmarks/quantized_training/pretrain_llama2.py b/benchmarks/quantized_training/pretrain_llama2.py new file mode 100644 index 0000000000..344a3a71af --- /dev/null +++ b/benchmarks/quantized_training/pretrain_llama2.py @@ -0,0 +1,150 @@ +# pre-train a mini Llama2 on TinyStories with INT8 quantized training +# pip install transformers sentencepiece wandb +# +# BF16 baseline: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile +# INT8 QT: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile --quantize int8_weight_only + +import os + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" + +import argparse +from pathlib import Path + +import numpy as np +import torch +import wandb +from tqdm import tqdm +from transformers import LlamaConfig, LlamaForCausalLM + +from torchao.prototype import low_bit_optim +from torchao.prototype.quantized_training import int8_weight_only_quantized_training +from torchao.quantization.quant_api import quantize_ + + +def get_loss(model: LlamaForCausalLM, batch: torch.Tensor): + return model(batch, labels=batch).loss + + +def get_tinystories(): + save_path = Path("tinystories.bin") + + if not save_path.exists(): + import sentencepiece as spm + from huggingface_hub import hf_hub_download + + tokenizer_path = hf_hub_download("meta-llama/Llama-2-7b", "tokenizer.model") + tokenizer = spm.SentencePieceProcessor(tokenizer_path) + assert tokenizer.vocab_size() < (1 << 16) # make sure we can use uint16 + + # do everything in memory. we have enough RAM + filepath = hf_hub_download( + "roneneldan/TinyStories", + "TinyStoriesV2-GPT4-train.txt", + repo_type="dataset", + ) + stories = open(filepath).read().split("\n<|endoftext|>\n") + + tokens_list = [] + chunk_size = 10_000 + for i in tqdm(range(0, len(stories), chunk_size), desc="Tokenizing TinyStories"): + chunk = stories[i : min(i + chunk_size, len(stories))] + tokens_list.extend(tokenizer.Encode(chunk, add_bos=True, add_eos=True, num_threads=4)) + + total_size = sum(len(x) for x in tokens_list) + mmap_tokens = np.memmap(save_path, dtype=np.uint16, mode="w+", shape=total_size) + i = 0 + for tokens in tokens_list: + mmap_tokens[i : i + len(tokens)] = tokens + i += len(tokens) + mmap_tokens.flush() + + tokens = np.memmap(save_path, dtype=np.uint16, mode="r") + return torch.from_numpy(tokens) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # default config is 470M + parser.add_argument("--d_model", type=int, default=1024) + parser.add_argument("--depth", type=int, default=24) + parser.add_argument("--ffn_size", type=int, default=4096) + parser.add_argument("--head_dim", type=int, default=64) + + parser.add_argument("--quantize") + parser.add_argument("--activation_checkpointing", action="store_true") + parser.add_argument("--compile", action="store_true") + + parser.add_argument("--n_steps", type=int, default=1000) + parser.add_argument("--batch_size", type=int, default=4) + parser.add_argument("--seq_len", type=int, default=2048) + + parser.add_argument("--optim", default="AdamW") + parser.add_argument("--lr", type=float, default=3e-4) + parser.add_argument("--weight_decay", type=float, default=1e-2) + + parser.add_argument("--project", default="int8_quantized_training") + parser.add_argument("--run_name") + parser.add_argument("--seed", type=int) + args = parser.parse_args() + + if args.seed is not None: + torch.manual_seed(args.seed) + + config = LlamaConfig( + hidden_size=args.d_model, + 
intermediate_size=args.ffn_size, + num_hidden_layers=args.depth, + num_attention_heads=args.d_model // args.head_dim, + max_position_embeddings=args.seq_len, + use_cache=False, + ) + model = LlamaForCausalLM(config).bfloat16().cuda() + if args.activation_checkpointing: + model.gradient_checkpointing_enable() + if args.quantize == "int8_weight_only": + quantize_(model, int8_weight_only_quantized_training(), set_inductor_config=False) + elif args.quantize is not None: + raise ValueError(f"Unsupported quantize={args.quantize}") + print(f"No. of params: {sum(p.numel() for p in model.parameters()):,}") + print(f"No. of buffers: {sum(p.numel() for p in model.buffers()):,}") + + # only use optimizers from torchao.prototype.low_bit_optim to support quantized training + if args.optim == "AdamW": + args.optim = "_AdamW" + optim = getattr(low_bit_optim, args.optim)(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) + + data = get_tinystories().cuda() + run = wandb.init(dir="/tmp", config=args, project=args.project, name=args.run_name) + + step = 0 + log_interval = 50 + pbar = tqdm(total=args.n_steps, dynamic_ncols=True) + model.train() + _get_loss = torch.compile(get_loss) if args.compile else get_loss + + while step < args.n_steps: + # randomly select a continuous chunk, then reshape it + idx = torch.randint(0, data.shape[0] - args.batch_size * args.seq_len, (1,)).item() + batch = data[idx : idx + args.batch_size * args.seq_len].view(args.batch_size, args.seq_len).long() + + loss = _get_loss(model, batch) + loss.backward() + + if step % log_interval == 0: + log_dict = dict( + loss=loss.item(), + lr=optim.param_groups[0]["lr"], + max_memory_allocated=torch.cuda.max_memory_allocated() / 1e9, + max_memory_active=torch.cuda.memory_stats().get("active_bytes.all.peak", 0) / 1e9, + ) + run.log(log_dict, step=step) + pbar.set_postfix(loss=log_dict["loss"]) + + optim.step() + optim.zero_grad() + + step += 1 + pbar.update() + + run.finish() diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py new file mode 100644 index 0000000000..6b4b6a6be9 --- /dev/null +++ b/test/prototype/test_quantized_training.py @@ -0,0 +1,225 @@ +import copy + +import pytest +import torch +import torch.nn.functional as F +from torch import nn +from torch.testing._internal.common_distributed import skip_if_lt_x_gpu +from torch.testing._internal.common_fsdp import FSDPTest +from torch.testing._internal.common_utils import TestCase, instantiate_parametrized_tests, parametrize, run_tests + +from torchao.prototype.low_bit_optim import _AdamW +from torchao.prototype.quantized_training import Int8QTLinearWeight, int8_weight_only_quantized_training +from torchao.quantization.quant_api import quantize_ +from torchao.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4 + +if not TORCH_VERSION_AFTER_2_3: + pytest.skip("Requires torch>=2.4", allow_module_level=True) + + +_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) + + +def _reset(): + # using TF32 will cause mixed mm to segfault with triton backend + # fixed in nightly by https://github.com/pytorch/pytorch/pull/133173 + # also required for correctness check + torch.set_float32_matmul_precision("highest") + torch._dynamo.reset() + + +# we always use `quantize_(set_inductor_config=False)` to reduce compile time in CI. 
+class TestQuantizedTraining(TestCase): + @parametrize("device", _DEVICES) + def test_int8_stochastic_rounding(self, device): + x = torch.randn(32, device=device) + x_samples = x.view(1, -1).repeat(100_000, 1) + + x_int8, x_scale = Int8QTLinearWeight.quantize(x_samples, stochastic_rounding=True) + x_dequant_samples = x_int8 * x_scale.view(-1, 1) + x_dequant_mean = x_dequant_samples.mean(0) + + # a more rigorous test would be to do a hypothesis testing. + # due to the statistical nature, this assertion may still fail, though very rarely. + torch.testing.assert_close(x_dequant_mean, x, atol=1e-4, rtol=1e-4) + + @parametrize("leading_dims", [(), (2,), (2, 4)]) + @parametrize("bias", [False, True]) + @parametrize("device", _DEVICES) + def test_int8_linear(self, leading_dims, bias, device): + _reset() + embed_dim = 32 + + linear_fp32 = nn.Linear(embed_dim, embed_dim, bias=bias, device=device) + linear_int8 = copy.deepcopy(linear_fp32) + quantize_(linear_int8, int8_weight_only_quantized_training(), set_inductor_config=False) + linear_fp32.weight.data = linear_int8.weight.data.dequantize() + + input_fp32 = torch.randn(leading_dims + (embed_dim,), device=device) + input_int8 = input_fp32.clone() + input_fp32.requires_grad_(True) + input_int8.requires_grad_(True) + + # test forward + out_fp32 = linear_fp32(input_fp32) + out_int8 = linear_int8(input_int8) + torch.testing.assert_close(out_fp32, out_int8) + + # test backward + grad = torch.randn(leading_dims + (embed_dim,), device=device) + out_fp32.backward(grad) + out_int8.backward(grad) + torch.testing.assert_close(input_fp32.grad, input_int8.grad) + torch.testing.assert_close(linear_fp32.weight.grad, linear_int8.weight.grad) + if bias: + torch.testing.assert_close(linear_fp32.bias.grad, linear_int8.bias.grad) + + @parametrize("leading_dims", [(), (2,), (2, 4)]) + @parametrize("bias", [False, True]) + @parametrize("device", _DEVICES) + def test_int8_linear_compile(self, leading_dims, bias, device): + _reset() + embed_dim = 128 + + linear_eager = nn.Linear(embed_dim, embed_dim, bias=bias, device=device) + quantize_(linear_eager, int8_weight_only_quantized_training(), set_inductor_config=False) + linear_compiled = copy.deepcopy(linear_eager) + linear_compiled.compile() + + input_eager = torch.randn(leading_dims + (embed_dim,), device=device) * 10 + input_compiled = input_eager.clone() + input_eager.requires_grad_(True) + input_compiled.requires_grad_(True) + + out_eager = linear_eager(input_eager) + out_compiled = linear_compiled(input_compiled) + torch.testing.assert_close(out_eager, out_compiled) + + grad = torch.randn(leading_dims + (embed_dim,), device=device) + out_eager.backward(grad) + out_compiled.backward(grad) + torch.testing.assert_close(input_eager.grad, input_compiled.grad) + torch.testing.assert_close(linear_eager.weight.grad, linear_compiled.weight.grad) + if bias: + torch.testing.assert_close(linear_eager.bias.grad, linear_compiled.bias.grad) + + @parametrize("compile", [False, True]) + @parametrize("device", _DEVICES) + def test_int8_linear_training(self, compile, device): + _reset() + bsize = 4 + embed_dim = 32 + n_classes = 10 + + model_fp32 = nn.Sequential( + nn.Linear(embed_dim, embed_dim * 2, bias=False), + nn.GELU(), + nn.Linear(embed_dim * 2, n_classes), + ).to(device) + model_int8 = copy.deepcopy(model_fp32) + # don't set inductor flags to speed up CI time + quantize_(model_int8, int8_weight_only_quantized_training(), set_inductor_config=False) + + if compile: + model_fp32.compile() + model_int8.compile() + + optim_fp32 = 
_AdamW(model_fp32.parameters()) + optim_int8 = _AdamW(model_int8.parameters()) + + for _ in range(5): + inputs = torch.randn(bsize, embed_dim, device=device) + labels = torch.randint(n_classes, size=(bsize,), device=device) + loss_fp32 = F.cross_entropy(model_fp32(inputs), labels) + loss_int8 = F.cross_entropy(model_int8(inputs), labels) + + rel_error = abs(loss_int8.item() - loss_fp32.item()) / abs(loss_fp32.item()) + assert rel_error < 2e-3, rel_error + + loss_fp32.backward() + optim_fp32.step() + optim_fp32.zero_grad() + + loss_int8.backward() + optim_int8.step() + optim_int8.zero_grad() + + +class TestFSDP2(FSDPTest): + @property + def world_size(self) -> int: + return 2 + + @skip_if_lt_x_gpu(2) + def test_fsdp2(self): + # FSDP2 + compiled quantized training fails with PyTorch 2.4 + compile_layer_choices = [False] + if TORCH_VERSION_AFTER_2_4: + compile_layer_choices.append(True) + + self.run_subtests( + {"compile_layer": compile_layer_choices}, + self._test_fsdp2, + ) + + def _test_fsdp2(self, compile_layer): + import torch.distributed as dist + from torch.distributed._composable.fsdp import fully_shard + from torch.testing._internal.distributed._tensor.common_dtensor import ModelArgs, Transformer + + _reset() + batch_size = 3 + vocab_size = 32 + seq_len = 64 + model_args = ModelArgs( + n_layers=2, + n_heads=2, + dim=128, + vocab_size=vocab_size, + max_seq_len=seq_len, + dropout_p=0, + ) + torch.manual_seed(42) + base_model = Transformer(model_args).cuda() + quantize_(base_model, int8_weight_only_quantized_training(), set_inductor_config=False) + fsdp_model = copy.deepcopy(base_model) + + if compile_layer: + for layer in base_model.layers: + layer.compile() + + for layer in fsdp_model.layers: + if compile_layer: + layer.compile() + fully_shard(layer) + fully_shard(fsdp_model) + + base_optim = torch.optim.Adam(base_model.parameters(), lr=1e-2, foreach=False, fused=False) + fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-2, foreach=False, fused=False) + + torch.manual_seed(42 + self.rank + 1) + for iter_idx in range(5): + inp = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda") + fsdp_optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) + fsdp_loss = fsdp_model(inp).sum() + fsdp_loss.backward() + fsdp_optim.step() + + base_optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) + base_loss = base_model(inp).sum() + base_loss.backward() + for param in base_model.parameters(): + if param.grad is not None: + dist.all_reduce(param.grad, op=dist.ReduceOp.AVG) + base_optim.step() + + # due to stochastic rounding, use a pretty large tolerance here + rel_error = (fsdp_loss - base_loss).abs() / base_loss.abs() + assert rel_error < 0.05, rel_error + + +instantiate_parametrized_tests(TestQuantizedTraining) + + +if __name__ == "__main__": + run_tests() diff --git a/torchao/prototype/low_bit_optim/__init__.py b/torchao/prototype/low_bit_optim/__init__.py index 01729bc6a3..5e9cc50c67 100644 --- a/torchao/prototype/low_bit_optim/__init__.py +++ b/torchao/prototype/low_bit_optim/__init__.py @@ -1,3 +1,3 @@ from .adam import Adam8bit, Adam4bit, AdamFp8 -from .adamw import AdamW8bit, AdamW4bit, AdamWFp8 +from .adamw import _AdamW, AdamW8bit, AdamW4bit, AdamWFp8 from .cpu_offload import CPUOffloadOptimizer diff --git a/torchao/prototype/low_bit_optim/adam.py b/torchao/prototype/low_bit_optim/adam.py index 47a99c06dc..a5425e9840 100644 --- a/torchao/prototype/low_bit_optim/adam.py +++ b/torchao/prototype/low_bit_optim/adam.py @@ -10,7 +10,7 @@ from .subclass_fp8 import OptimStateFp8 
-class _Adam(Optimizer): +class _AdamBase(Optimizer): def __init__(self, params, lr, betas, eps, weight_decay, amsgrad, *, block_size) -> None: if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) @@ -155,7 +155,7 @@ def single_param_adam( p.addcdiv_(new_exp_avg, denom, value=-step_size) -class Adam8bit(_Adam): +class Adam8bit(_AdamBase): def __init__( self, params, @@ -174,7 +174,7 @@ def _subclass_zeros(p: Tensor, signed: bool, block_size: int): return OptimState8bit.zeros(p.shape, signed, block_size, p.device) -class Adam4bit(_Adam): +class Adam4bit(_AdamBase): def __init__( self, params, @@ -233,7 +233,7 @@ def step(self, closure=None): return loss -class AdamFp8(_Adam): +class AdamFp8(_AdamBase): def __init__( self, params, diff --git a/torchao/prototype/low_bit_optim/adamw.py b/torchao/prototype/low_bit_optim/adamw.py index dbde91fdd2..9d1df8e6c8 100644 --- a/torchao/prototype/low_bit_optim/adamw.py +++ b/torchao/prototype/low_bit_optim/adamw.py @@ -10,7 +10,7 @@ from .subclass_fp8 import OptimStateFp8 -class _AdamW(Optimizer): +class _AdamWBase(Optimizer): def __init__(self, params, lr, betas, eps, weight_decay, amsgrad, *, block_size) -> None: if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) @@ -131,8 +131,6 @@ def single_param_adamw( weight_decay: float, eps: float, ): - p.mul_(1 - lr * weight_decay) - bias_correction1 = 1 - beta1 ** step bias_correction2 = 1 - beta2 ** step @@ -150,11 +148,27 @@ def single_param_adamw( else: denom = (new_exp_avg_sq.sqrt() / bias_correction2.sqrt()).add_(eps) + # merge weight decay and param update in a single .add_() to make this work with quantized param step_size = lr / bias_correction1 - p.addcdiv_(new_exp_avg, denom, value=-step_size) + p.add_(-lr * weight_decay * p - step_size * new_exp_avg / denom) + + +class _AdamW(_AdamWBase): + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2, + amsgrad=False, + ) -> None: + """AdamW optimizer that supports quantized training (parameter is quantized). This optimizer should + only be used with torchao's quantized training.""" + super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=float("inf")) -class AdamW8bit(_AdamW): +class AdamW8bit(_AdamWBase): def __init__( self, params, @@ -173,7 +187,7 @@ def _subclass_zeros(p: Tensor, signed: bool, block_size: int): return OptimState8bit.zeros(p.shape, signed, block_size, p.device) -class AdamW4bit(_AdamW): +class AdamW4bit(_AdamWBase): def __init__( self, params, @@ -232,7 +246,7 @@ def step(self, closure=None): return loss -class AdamWFp8(_AdamW): +class AdamWFp8(_AdamWBase): def __init__( self, params, diff --git a/torchao/prototype/quantized_training/README.md b/torchao/prototype/quantized_training/README.md new file mode 100644 index 0000000000..9b2980aa2b --- /dev/null +++ b/torchao/prototype/quantized_training/README.md @@ -0,0 +1,53 @@ +# Quantized training + +This folder contains experimental work on quantized training (QT). The main difference from quantization-aware training (QAT) is that in QT, we don't keep a high-precision copy of model weights. 
We take inspiration from:
+- Q-GaLore: [[paper](https://arxiv.org/abs/2407.08296)] [[code](https://github.com/VITA-Group/Q-GaLore)]
+- AQT: [[related paper](https://arxiv.org/abs/2105.03536)] [[code](https://github.com/google/aqt)]
+
+Typically, low-precision weights cannot be trained directly due to quantization error: a small change in the quantized weight will simply be rounded down to zero. To tackle this problem, we use **stochastic rounding** for the weight update. In simple terms, stochastic rounding rounds up or down randomly, with a higher chance of rounding towards the nearer value. For example, 0.8 has an 80% chance of rounding up and a 20% chance of rounding down. It also follows that, on average, stochastic rounding estimates the floating point value without bias.
+
+In precise terms, the probability of rounding up is `x - ⌊x⌋`. Note that when the value is exactly an integer, the probability of rounding up is zero. (A short code sketch of this rounding scheme is included at the end of this README.)
+
+There are two main benefits to training this way:
+1. Reduced memory footprint, as well as reduced communication bandwidth in distributed settings.
+2. What you train is what you serve ([WYTIWYS](https://github.com/google/aqt?tab=readme-ov-file#features)).
+
+Currently we only support weight-only, channel-wise INT8 symmetric quantization.
+
+## INT8 weight only
+
+In this recipe, all linear weights are quantized to INT8 using channel-wise symmetric quantization over `[-127, 127]`. In the forward and backward passes, the weights are upcast to the activations' dtype (e.g. BF16); their gradients are therefore also in the activations' dtype.
+
+Usage:
+
+```python
+from torchao.prototype.quantized_training import int8_weight_only_quantized_training
+from torchao.prototype.low_bit_optim import _AdamW
+from torchao.quantization.quant_api import quantize_
+
+model = ...
+quantize_(model, int8_weight_only_quantized_training())
+
+optim = _AdamW(model.parameters(), lr=3e-4)
+```
+
+Only `torch.optim.Adam` and optimizers from `torchao.prototype.low_bit_optim` are known to work with quantized training in this folder. This is because we implement the stochastic rounding logic inside the tensor subclass instead of the optimizer. We provide `torchao.prototype.low_bit_optim._AdamW` as an alternative to `torch.optim.AdamW` specifically for this purpose.
+
+[`benchmarks/quantized_training/pretrain_llama2.py`](../../../benchmarks/quantized_training/pretrain_llama2.py) demonstrates end-to-end Llama2 pre-training using this INT8 quantized training recipe.
+
+See [#644](https://github.com/pytorch/ao/pull/644) for some early results.
+
+TODO: investigate suboptimal memory saving when `torch.compile()` is used. Might be due to the transposed weight. Memory benchmark for Llama2-1B, bs=4, seq_len=2048, activation checkpointing:
+
+Model | Peak memory (GB)
+----------------|-----------------
+BF16 eager | 11.06847
+BF16 compile | 10.16915
+INT8 QT eager | 10.11437
+INT8 QT compile | 10.03365
+
+## Future ideas
+
+- INT8 activation x INT8 weight. This can potentially leverage INT8 Tensor Cores, which are 2x faster than FP16/BF16 Tensor Cores.
+- INT4 weight only (with group-wise quantization). This can be used with INT4 tinygemm deployment in mind (or other optimized INT4 kernels).
+- FP8 activation x FP8 weight. The current FP8 training recipe can be seen as a form of QAT, which maintains a high-precision copy of model weights. We can eliminate that high-precision copy.
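+
+## Appendix: stochastic rounding sketch
+
+The snippet below is a minimal, illustrative sketch of channel-wise INT8 quantization with stochastic rounding, written with plain PyTorch ops. The function name and the 2D weight shape are assumptions for illustration only; the actual implementation used by this recipe is `Int8QTLinearWeight.quantize()`.
+
+```python
+import torch
+
+def int8_quantize_stochastic(weight: torch.Tensor):
+    # weight is assumed to be 2D: (out_features, in_features)
+    # channel-wise absmax scale so that values map into [-127, 127]
+    scale = weight.float().abs().amax(dim=-1) / 127
+    q = weight.float() / scale.clip(1e-12).view(-1, 1)
+    # stochastic rounding: round up with probability x - floor(x)
+    q = (q + torch.rand_like(q)).floor()
+    return q.clip(-128, 127).to(torch.int8), scale.to(weight.dtype)
+```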
diff --git a/torchao/prototype/quantized_training/__init__.py b/torchao/prototype/quantized_training/__init__.py new file mode 100644 index 0000000000..6c7f8eb9b1 --- /dev/null +++ b/torchao/prototype/quantized_training/__init__.py @@ -0,0 +1 @@ +from .int8 import Int8QTLinearWeight, int8_weight_only_quantized_training diff --git a/torchao/prototype/quantized_training/int8.py b/torchao/prototype/quantized_training/int8.py new file mode 100644 index 0000000000..c301f011c2 --- /dev/null +++ b/torchao/prototype/quantized_training/int8.py @@ -0,0 +1,271 @@ +from typing import Any, Optional, Tuple + +import torch +from torch import Tensor, nn +from torch.utils._python_dispatch import return_and_correct_aliasing + +from torchao.dtypes.utils import _dispatch__torch_dispatch__, _dispatch__torch_function__, _implements + + +aten = torch.ops.aten +c10d_functional = torch.ops.c10d_functional +_c10d_functional = torch.ops._c10d_functional + + +class Int8QTLinearWeight(Tensor): + """INT8 symmetric quantization weight, with absmax scaling [-127, 127]. The main difference + of this tensor subclass from AffineQuantizedTensor: + 1. `F.linear` is differentiable i.e. backward is defined. + 2. All in-place ops, such as `aten.copy_`, will perform stochastic rounding. + `Int8QTLinearWeight.from_float()` does not perform stochastic rounding. + 3. The numerics for quantization is slightly different. See `Int8QTLinearWeight.quantize()` + for more details. + """ + + implements = classmethod(_implements) + __torch_function__ = classmethod(_dispatch__torch_function__) + __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__) + + @staticmethod + @torch._dynamo.disable + def __new__(cls, int_data: Tensor, scale: Tensor): + return Tensor._make_wrapper_subclass( + cls, + int_data.shape, + dtype=scale.dtype, + device=int_data.device, + ) + + @torch._dynamo.disable + def __init__(self, int_data: Tensor, scale: Tensor): + """Create a symmetric quantized INT8 weight. This tensor will appear to have the same dtype + as `scale.dtype`. All in-place update ops will perform stochastic rounding. + """ + # NOTE: should scale always be FP32? + assert int_data.dtype is torch.int8 + assert int_data.ndim == 2 + assert scale.ndim == 1 + self.int_data = int_data + self.scale = scale + + def __tensor_flatten__(self): + return ["int_data", "scale"], [] + + @classmethod + def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None): + return cls(tensor_data_dict["int_data"], tensor_data_dict["scale"], *tensor_attributes) + + @staticmethod + @torch.no_grad() + def quantize(tensor: Tensor, stochastic_rounding: bool = False): + """Normal rounding will always round down small changes in weight update. To tackle this problem, + stochastic rounding can be used, which has a low chance, but not zero, of rounding up. The + probability of rounding up is equal to x - ⌊x⌋, which indicates how close the value is to the next + integer value. Thus, stochastic rounding also approximates the floating point value exactly. + + Currently this function differs from AQT's `int8_weight_only()` in the following way: + 1. Precision: AQT keeps original dtype when doing quantization, while this function upcasts input + to FP32 before quantization, and downcast scale to original dtype. + 2. Calculate scale: AQT uses `input.abs().amax() / 127.5`, while `input.abs().amax() / 127` is + done here. + 3. Apply scale: AQT uses `input * (1 / scale)`, while this function performs `input / scale`. 
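+
+        Example (illustrative usage; `w` is just a stand-in 2D float tensor):
+
+            w = torch.randn(64, 32)
+            int_data, scale = Int8QTLinearWeight.quantize(w, stochastic_rounding=True)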
+ """ + original_dtype = tensor.dtype + tensor = tensor.float() + + # absmax symmetric quantization + scale = tensor.abs().amax(-1) / 127 + tensor = tensor / scale.clip(1e-12).view(-1, 1) + + if stochastic_rounding: + tensor = (tensor + torch.rand_like(tensor)).floor() + else: + tensor = tensor.round() + + tensor = tensor.clip(-128, 127).to(torch.int8) + return tensor, scale.to(original_dtype) + + @classmethod + def from_float(cls, tensor: Tensor): + """Convert a float tensor into INT8 quantized weight. No stochastic rounding is performed. + This function is not differentiable. + """ + int_data, scale = cls.quantize(tensor.detach()) + out = cls(int_data, scale) + out.requires_grad_(tensor.requires_grad) + return out + + def dequantize(self): + return self.int_data * self.scale.view(-1, 1) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(shape={tuple(self.shape)}, dtype={self.dtype}, device={self.device}, " + f"requires_grad={self.requires_grad})" + ) + + def fsdp_pre_all_gather(self, mesh): + return (self.int_data, self.scale), None + + def fsdp_post_all_gather( + self, + all_gather_outputs: Tuple[Tensor, ...], + metadata: Any, + param_dtype: torch.dtype, + *, + out: Optional[Tensor] = None, + ): + int_data, scale = all_gather_outputs + return Int8QTLinearWeight(int_data, scale), all_gather_outputs + + +class _Int8WeightOnlyLinear(torch.autograd.Function): + @staticmethod + def forward(ctx, input: Tensor, weight: Int8QTLinearWeight, bias: Optional[Tensor] = None): + ctx.save_for_backward(input, weight) + ctx.bias = bias is not None + + # NOTE: we have to .T before .to(input.dtype) for torch.compile() mixed matmul to work + out = (input @ weight.int_data.T.to(input.dtype)) * weight.scale + out = out + bias if bias is not None else out + return out + + @staticmethod + def backward(ctx, grad_output): + input, weight = ctx.saved_tensors + + grad_input = (grad_output * weight.scale) @ weight.int_data.to(grad_output.dtype) + grad_weight = grad_output.view(-1, weight.shape[0]).T @ input.view(-1, weight.shape[1]) + grad_bias = grad_output.view(-1, weight.shape[0]).sum(0) if ctx.bias else None + return grad_input, grad_weight, grad_bias + + +@Int8QTLinearWeight.implements(torch.nn.functional.linear) +def _(func, types, args, kwargs): + return _Int8WeightOnlyLinear.apply(*args, **kwargs) + + +@Int8QTLinearWeight.implements( + [ + aten.detach.default, + aten.clone.default, + # FSDP ops + aten.slice.Tensor, + c10d_functional.all_gather_into_tensor.default, + _c10d_functional.all_gather_into_tensor.default, + c10d_functional.wait_tensor.default, + _c10d_functional.wait_tensor.default, + ] +) +def _(func, types, args, kwargs): + # will error out if try to slice 2nd dim + out = Int8QTLinearWeight( + func(args[0].int_data, *args[1:], **kwargs), + func(args[0].scale, *args[1:], **kwargs), + ) + return return_and_correct_aliasing(func, args, kwargs, out) + + +@Int8QTLinearWeight.implements(aten._to_copy.default) +def _(func, types, args, kwargs): + # only perform dtype casting on scale, which determines the appearance dtype + # TODO: handle non_blocking kwarg? 
+ device = kwargs.get("device", None) + dtype = kwargs.get("dtype", None) + out = Int8QTLinearWeight( + args[0].int_data.to(device=device), + args[0].scale.to(device=device, dtype=dtype), + ) + return return_and_correct_aliasing(func, args, kwargs, out) + + +# to make training work with existing PyTorch optimizers, we return a normal tensor +@Int8QTLinearWeight.implements(aten.zeros_like.default) +def _(func, types, args, kwargs): + dtype = kwargs.get("dtype", args[0].dtype) + device = kwargs.get("device", args[0].device) + return torch.zeros(args[0].shape, dtype=dtype, device=device) + + +# out-of-place math ops always return plain tensor +@Int8QTLinearWeight.implements([aten.sub.Tensor, aten.mul.Tensor]) +def _(func, types, args, kwargs): + args = [x.dequantize() if isinstance(x, Int8QTLinearWeight) else x for x in args] + return func(*args, **kwargs) + + +@Int8QTLinearWeight.implements(aten.copy_.default) +def _(func, types, args, kwargs): + if isinstance(args[0], Int8QTLinearWeight) and isinstance(args[1], Int8QTLinearWeight): + args[0].int_data.copy_(args[1].int_data, **kwargs) + args[0].scale.copy_(args[1].scale, **kwargs) + + elif isinstance(args[0], Int8QTLinearWeight): + int_data, scale = Int8QTLinearWeight.quantize(args[1], stochastic_rounding=True) + args[0].int_data.copy_(int_data, **kwargs) + args[0].scale.copy_(scale, **kwargs) + + else: + args[0].copy_(args[1].dequantize(), **kwargs) + + return args[0] + + +@Int8QTLinearWeight.implements([aten.addcdiv_.default, aten.add_.Tensor]) +def _(func, types, args, kwargs): + original = args[0] + out = func(args[0].dequantize(), *args[1:], **kwargs) + return original.copy_(out) + + +# FSDP ops +@Int8QTLinearWeight.implements(aten.split.Tensor) +def _(func, types, args, kwargs): + if len(args) == 3 and args[2] != 0: + raise NotImplementedError("Int8QTLinearWeight only supports split at dim=0") + + int8_weight: Int8QTLinearWeight = args[0] + int_data_list = func(int8_weight.int_data, *args[1:], **kwargs) + scale_list = func(int8_weight.scale, *args[1:], **kwargs) + + out = [Int8QTLinearWeight(int_data, scale) for int_data, scale in zip(int_data_list, scale_list)] + return out + + +@Int8QTLinearWeight.implements(aten.new_zeros.default) +def _(func, types, args, kwargs): + size = args[1] + if len(size) != 2: + raise NotImplementedError + + # TODO: handle pin_memory kwarg? + device = kwargs.get("device", args[0].device) + dtype = kwargs.get("dtype", args[0].dtype) + int_data = torch.zeros(size, device=device, dtype=torch.int8) + scale = torch.zeros(size[0], device=device, dtype=dtype) + return Int8QTLinearWeight(int_data, scale) + + +# FSDP2 will call these two ops, expecting a view, not a copy. It doesn't make sense to +# correctly support these ops. For example, `.scale` depends on the shape of the weight, +# since this is channel-wise quantization. +# Thus, this is a workaround for FSDP2. Users SHOULD NOT call these ops directly, since +# they will produce unexpected or wrong results. +@Int8QTLinearWeight.implements([aten.view.default, aten.as_strided.default]) +def _(func, types, args, kwargs): + out = Int8QTLinearWeight(args[0].int_data, args[0].scale) + return return_and_correct_aliasing(func, args, kwargs, out) + + +def int8_weight_only_quantized_training(): + # TODO: right now `_get_linear_subclass_inserter()` will always set `requires_grad=False` + # when we have this out of prototype (or there are stable trainable tensor subclasses), + # update `_get_linear_subclass_inserter()` to allow `requires_grad=True`. 
+ def apply_int8_linear_weight(linear: nn.Linear): + linear.weight = nn.Parameter( + Int8QTLinearWeight.from_float(linear.weight), + requires_grad=linear.weight.requires_grad, + ) + return linear + + return apply_int8_linear_weight
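+
+
+# Example usage (same as the README in this folder); `model` here is any module containing nn.Linear layers:
+#
+#   from torchao.quantization.quant_api import quantize_
+#   quantize_(model, int8_weight_only_quantized_training())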