import copy

import pytest
import torch
import torch.nn.functional as F
from torch import nn
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import TestCase, instantiate_parametrized_tests, parametrize, run_tests

from torchao.prototype.low_bit_optim import _AdamW
from torchao.prototype.quantized_training import Int8QTLinearWeight, int8_weight_only_quantized_training
from torchao.quantization.quant_api import quantize_
from torchao.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4

if not TORCH_VERSION_AFTER_2_3:
    pytest.skip("Requires torch>=2.4", allow_module_level=True)


_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])


def _reset():
    # using TF32 will cause mixed mm to segfault with triton backend
    # fixed in nightly by https://github.com/pytorch/pytorch/pull/133173
    # also required for correctness check
    torch.set_float32_matmul_precision("highest")
    torch._dynamo.reset()


# we always use `quantize_(set_inductor_config=False)` to reduce compile time in CI.
class TestQuantizedTraining(TestCase):
    @parametrize("device", _DEVICES)
    def test_int8_stochastic_rounding(self, device):
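        # quantize 100k copies of the same row with stochastic rounding enabled;
        # the dequantized values should be unbiased on average, so their sample
        # mean should converge to the original fp32 vector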
        x = torch.randn(32, device=device)
        x_samples = x.view(1, -1).repeat(100_000, 1)

        x_int8, x_scale = Int8QTLinearWeight.quantize(x_samples, stochastic_rounding=True)
        x_dequant_samples = x_int8 * x_scale.view(-1, 1)
        x_dequant_mean = x_dequant_samples.mean(0)

        # a more rigorous check would be a proper hypothesis test.
        # due to the statistical nature of this test, the assertion may still fail, though very rarely.
        torch.testing.assert_close(x_dequant_mean, x, atol=1e-4, rtol=1e-4)

    @parametrize("leading_dims", [(), (2,), (2, 4)])
    @parametrize("bias", [False, True])
    @parametrize("device", _DEVICES)
    def test_int8_linear(self, leading_dims, bias, device):
        _reset()
        embed_dim = 32

        linear_fp32 = nn.Linear(embed_dim, embed_dim, bias=bias, device=device)
        linear_int8 = copy.deepcopy(linear_fp32)
        quantize_(linear_int8, int8_weight_only_quantized_training(), set_inductor_config=False)
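        # use the dequantized int8 weight as the fp32 reference so both modules
        # compute with the same effective weight values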
        linear_fp32.weight.data = linear_int8.weight.data.dequantize()

        input_fp32 = torch.randn(leading_dims + (embed_dim,), device=device)
        input_int8 = input_fp32.clone()
        input_fp32.requires_grad_(True)
        input_int8.requires_grad_(True)

        # test forward
        out_fp32 = linear_fp32(input_fp32)
        out_int8 = linear_int8(input_int8)
        torch.testing.assert_close(out_fp32, out_int8)

        # test backward
        grad = torch.randn(leading_dims + (embed_dim,), device=device)
        out_fp32.backward(grad)
        out_int8.backward(grad)
        torch.testing.assert_close(input_fp32.grad, input_int8.grad)
        torch.testing.assert_close(linear_fp32.weight.grad, linear_int8.weight.grad)
        if bias:
            torch.testing.assert_close(linear_fp32.bias.grad, linear_int8.bias.grad)

    @parametrize("leading_dims", [(), (2,), (2, 4)])
    @parametrize("bias", [False, True])
    @parametrize("device", _DEVICES)
    def test_int8_linear_compile(self, leading_dims, bias, device):
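        # compare eager vs. torch.compile execution of the quantized linear,
        # checking both forward outputs and gradients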
        _reset()
        embed_dim = 128

        linear_eager = nn.Linear(embed_dim, embed_dim, bias=bias, device=device)
        quantize_(linear_eager, int8_weight_only_quantized_training(), set_inductor_config=False)
        linear_compiled = copy.deepcopy(linear_eager)
        linear_compiled.compile()

        input_eager = torch.randn(leading_dims + (embed_dim,), device=device) * 10
        input_compiled = input_eager.clone()
        input_eager.requires_grad_(True)
        input_compiled.requires_grad_(True)

        out_eager = linear_eager(input_eager)
        out_compiled = linear_compiled(input_compiled)
        torch.testing.assert_close(out_eager, out_compiled)

        grad = torch.randn(leading_dims + (embed_dim,), device=device)
        out_eager.backward(grad)
        out_compiled.backward(grad)
        torch.testing.assert_close(input_eager.grad, input_compiled.grad)
        torch.testing.assert_close(linear_eager.weight.grad, linear_compiled.weight.grad)
        if bias:
            torch.testing.assert_close(linear_eager.bias.grad, linear_compiled.bias.grad)

    @parametrize("compile", [False, True])
    @parametrize("device", _DEVICES)
    def test_int8_linear_training(self, compile, device):
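        # end-to-end training check: a small MLP with int8 quantized weights should
        # track the fp32 reference loss closely at every step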
        _reset()
        bsize = 4
        embed_dim = 32
        n_classes = 10

        model_fp32 = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 2, bias=False),
            nn.GELU(),
            nn.Linear(embed_dim * 2, n_classes),
        ).to(device)
        model_int8 = copy.deepcopy(model_fp32)
        # don't set inductor flags to speed up CI time
        quantize_(model_int8, int8_weight_only_quantized_training(), set_inductor_config=False)

        if compile:
            model_fp32.compile()
            model_int8.compile()

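        # both models use the same optimizer (the low-bit optim prototype's AdamW variant)
        # so the only difference between them is the weight quantization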
        optim_fp32 = _AdamW(model_fp32.parameters())
        optim_int8 = _AdamW(model_int8.parameters())

        for _ in range(5):
            inputs = torch.randn(bsize, embed_dim, device=device)
            labels = torch.randint(n_classes, size=(bsize,), device=device)
            loss_fp32 = F.cross_entropy(model_fp32(inputs), labels)
            loss_int8 = F.cross_entropy(model_int8(inputs), labels)

            rel_error = abs(loss_int8.item() - loss_fp32.item()) / abs(loss_fp32.item())
            assert rel_error < 2e-3, rel_error

            loss_fp32.backward()
            optim_fp32.step()
            optim_fp32.zero_grad()

            loss_int8.backward()
            optim_int8.step()
            optim_int8.zero_grad()


class TestFSDP2(FSDPTest):
    @property
    def world_size(self) -> int:
        return 2

    @skip_if_lt_x_gpu(2)
    def test_fsdp2(self):
        # FSDP2 + compiled quantized training fails with PyTorch 2.4
        compile_layer_choices = [False]
        if TORCH_VERSION_AFTER_2_4:
            compile_layer_choices.append(True)

        self.run_subtests(
            {"compile_layer": compile_layer_choices},
            self._test_fsdp2,
        )

    def _test_fsdp2(self, compile_layer):
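        # train an FSDP2-sharded copy of a quantized Transformer alongside an unsharded
        # reference and check that their losses stay close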
        import torch.distributed as dist
        from torch.distributed._composable.fsdp import fully_shard
        from torch.testing._internal.distributed._tensor.common_dtensor import ModelArgs, Transformer

        _reset()
        batch_size = 3
        vocab_size = 32
        seq_len = 64
        model_args = ModelArgs(
            n_layers=2,
            n_heads=2,
            dim=128,
            vocab_size=vocab_size,
            max_seq_len=seq_len,
            dropout_p=0,
        )
        torch.manual_seed(42)
        base_model = Transformer(model_args).cuda()
        quantize_(base_model, int8_weight_only_quantized_training(), set_inductor_config=False)
        fsdp_model = copy.deepcopy(base_model)
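        # base_model is the unsharded reference; fsdp_model is sharded with fully_shard below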

        if compile_layer:
            for layer in base_model.layers:
                layer.compile()

        for layer in fsdp_model.layers:
            if compile_layer:
                layer.compile()
            fully_shard(layer)
        fully_shard(fsdp_model)

        base_optim = torch.optim.Adam(base_model.parameters(), lr=1e-2, foreach=False, fused=False)
        fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-2, foreach=False, fused=False)

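        # each rank draws different inputs; the reference model's gradients are all-reduced
        # manually so it sees the same averaged gradients as the FSDP-sharded model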
        torch.manual_seed(42 + self.rank + 1)
        for iter_idx in range(5):
            inp = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda")
            fsdp_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
            fsdp_loss = fsdp_model(inp).sum()
            fsdp_loss.backward()
            fsdp_optim.step()

            base_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
            base_loss = base_model(inp).sum()
            base_loss.backward()
            for param in base_model.parameters():
                if param.grad is not None:
                    dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
            base_optim.step()

            # due to stochastic rounding, use a pretty large tolerance here
            rel_error = (fsdp_loss - base_loss).abs() / base_loss.abs()
            assert rel_error < 0.05, rel_error


instantiate_parametrized_tests(TestQuantizedTraining)


if __name__ == "__main__":
    run_tests()