
Commit 8c6b4f9

Add experimental INT8 quantized training (#644)
* initial commit
* add tests
* add training
* support py3.9
* skip test for torch<2.3
* fix pytest
* fix adamw
* add some FSDP ops
* add more fsdp ops
* more ops
* add benchmark script
* some organisation
* add FSDP test
* clean up
* update FSDP test
* add compile test (things are crashing)
* fix bias
* substantial update to tests
* fix compile for FSDP
* update readme. rename file
* speed up CI
* fix typo
* fix typo
* typos. unset some dynamo flags
* update readme
* remove requires_grad, since it is unnecessary
* remove note
* don't set inductor flags
* rename
* update README
* rename optimizer
* update benchmark script
* make compile explicit
* update docs
* use torch.optim.Adam to avoid FSDP optim compile bug
* update docs
* update doc
* update docs
* fix CI test
* skip test
* fix compiled test
1 parent 0b66ff0 commit 8c6b4f9
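
For orientation, here is a minimal, hedged sketch of how the pieces added in this commit are meant to be combined; the toy model and hyperparameters are illustrative and not taken from the diff, and the full training recipe is in the benchmark script below.

# Minimal sketch only: assumes torchao with this commit installed and a CUDA GPU
# (the unit tests in this commit also cover CPU).
import torch
from torch import nn

from torchao.prototype import low_bit_optim
from torchao.prototype.quantized_training import int8_weight_only_quantized_training
from torchao.quantization.quant_api import quantize_

# placeholder model, not from the diff
model = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64)).bfloat16().cuda()

# swap nn.Linear weights for the INT8 quantized-training weight subclass,
# exactly as the benchmark script below does
quantize_(model, int8_weight_only_quantized_training(), set_inductor_config=False)

# use the _AdamW variant from low_bit_optim, which supports the quantized weights
optim = low_bit_optim._AdamW(model.parameters(), lr=3e-4)

x = torch.randn(8, 64, device="cuda", dtype=torch.bfloat16)
loss = model(x).sum()
loss.backward()
optim.step()
optim.zero_grad()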

File tree

8 files changed: +726 -12 lines

Lines changed: 150 additions & 0 deletions
@@ -0,0 +1,150 @@
# pre-train a mini Llama2 on TinyStories with INT8 quantized training
# pip install transformers sentencepiece wandb
#
# BF16 baseline: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile
# INT8 QT: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile --quantize int8_weight_only

import os

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import argparse
from pathlib import Path

import numpy as np
import torch
import wandb
from tqdm import tqdm
from transformers import LlamaConfig, LlamaForCausalLM

from torchao.prototype import low_bit_optim
from torchao.prototype.quantized_training import int8_weight_only_quantized_training
from torchao.quantization.quant_api import quantize_


def get_loss(model: LlamaForCausalLM, batch: torch.Tensor):
    return model(batch, labels=batch).loss


def get_tinystories():
    save_path = Path("tinystories.bin")

    if not save_path.exists():
        import sentencepiece as spm
        from huggingface_hub import hf_hub_download

        tokenizer_path = hf_hub_download("meta-llama/Llama-2-7b", "tokenizer.model")
        tokenizer = spm.SentencePieceProcessor(tokenizer_path)
        assert tokenizer.vocab_size() < (1 << 16)  # make sure we can use uint16

        # do everything in memory. we have enough RAM
        filepath = hf_hub_download(
            "roneneldan/TinyStories",
            "TinyStoriesV2-GPT4-train.txt",
            repo_type="dataset",
        )
        stories = open(filepath).read().split("\n<|endoftext|>\n")

        tokens_list = []
        chunk_size = 10_000
        for i in tqdm(range(0, len(stories), chunk_size), desc="Tokenizing TinyStories"):
            chunk = stories[i : min(i + chunk_size, len(stories))]
            tokens_list.extend(tokenizer.Encode(chunk, add_bos=True, add_eos=True, num_threads=4))

        total_size = sum(len(x) for x in tokens_list)
        mmap_tokens = np.memmap(save_path, dtype=np.uint16, mode="w+", shape=total_size)
        i = 0
        for tokens in tokens_list:
            mmap_tokens[i : i + len(tokens)] = tokens
            i += len(tokens)
        mmap_tokens.flush()

    tokens = np.memmap(save_path, dtype=np.uint16, mode="r")
    return torch.from_numpy(tokens)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # default config is 470M
    parser.add_argument("--d_model", type=int, default=1024)
    parser.add_argument("--depth", type=int, default=24)
    parser.add_argument("--ffn_size", type=int, default=4096)
    parser.add_argument("--head_dim", type=int, default=64)

    parser.add_argument("--quantize")
    parser.add_argument("--activation_checkpointing", action="store_true")
    parser.add_argument("--compile", action="store_true")

    parser.add_argument("--n_steps", type=int, default=1000)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--seq_len", type=int, default=2048)

    parser.add_argument("--optim", default="AdamW")
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--weight_decay", type=float, default=1e-2)

    parser.add_argument("--project", default="int8_quantized_training")
    parser.add_argument("--run_name")
    parser.add_argument("--seed", type=int)
    args = parser.parse_args()

    if args.seed is not None:
        torch.manual_seed(args.seed)

    config = LlamaConfig(
        hidden_size=args.d_model,
        intermediate_size=args.ffn_size,
        num_hidden_layers=args.depth,
        num_attention_heads=args.d_model // args.head_dim,
        max_position_embeddings=args.seq_len,
        use_cache=False,
    )
    model = LlamaForCausalLM(config).bfloat16().cuda()
    if args.activation_checkpointing:
        model.gradient_checkpointing_enable()
    if args.quantize == "int8_weight_only":
        quantize_(model, int8_weight_only_quantized_training(), set_inductor_config=False)
    elif args.quantize is not None:
        raise ValueError(f"Unsupported quantize={args.quantize}")
    print(f"No. of params: {sum(p.numel() for p in model.parameters()):,}")
    print(f"No. of buffers: {sum(p.numel() for p in model.buffers()):,}")

    # only use optimizers from torchao.prototype.low_bit_optim to support quantized training
    if args.optim == "AdamW":
        args.optim = "_AdamW"
    optim = getattr(low_bit_optim, args.optim)(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    data = get_tinystories().cuda()
    run = wandb.init(dir="/tmp", config=args, project=args.project, name=args.run_name)

    step = 0
    log_interval = 50
    pbar = tqdm(total=args.n_steps, dynamic_ncols=True)
    model.train()
    _get_loss = torch.compile(get_loss) if args.compile else get_loss

    while step < args.n_steps:
        # randomly select a contiguous chunk, then reshape it
        idx = torch.randint(0, data.shape[0] - args.batch_size * args.seq_len, (1,)).item()
        batch = data[idx : idx + args.batch_size * args.seq_len].view(args.batch_size, args.seq_len).long()

        loss = _get_loss(model, batch)
        loss.backward()

        if step % log_interval == 0:
            log_dict = dict(
                loss=loss.item(),
                lr=optim.param_groups[0]["lr"],
                max_memory_allocated=torch.cuda.max_memory_allocated() / 1e9,
                max_memory_active=torch.cuda.memory_stats().get("active_bytes.all.peak", 0) / 1e9,
            )
            run.log(log_dict, step=step)
            pbar.set_postfix(loss=log_dict["loss"])

        optim.step()
        optim.zero_grad()

        step += 1
        pbar.update()

    run.finish()

Lines changed: 225 additions & 0 deletions
@@ -0,0 +1,225 @@
import copy

import pytest
import torch
import torch.nn.functional as F
from torch import nn
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import TestCase, instantiate_parametrized_tests, parametrize, run_tests

from torchao.prototype.low_bit_optim import _AdamW
from torchao.prototype.quantized_training import Int8QTLinearWeight, int8_weight_only_quantized_training
from torchao.quantization.quant_api import quantize_
from torchao.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4

if not TORCH_VERSION_AFTER_2_3:
    pytest.skip("Requires torch>=2.4", allow_module_level=True)


_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])


def _reset():
    # using TF32 will cause mixed mm to segfault with triton backend
    # fixed in nightly by https://github.com/pytorch/pytorch/pull/133173
    # also required for correctness check
    torch.set_float32_matmul_precision("highest")
    torch._dynamo.reset()


# we always use `quantize_(set_inductor_config=False)` to reduce compile time in CI.
class TestQuantizedTraining(TestCase):
    @parametrize("device", _DEVICES)
    def test_int8_stochastic_rounding(self, device):
        x = torch.randn(32, device=device)
        x_samples = x.view(1, -1).repeat(100_000, 1)

        x_int8, x_scale = Int8QTLinearWeight.quantize(x_samples, stochastic_rounding=True)
        x_dequant_samples = x_int8 * x_scale.view(-1, 1)
        x_dequant_mean = x_dequant_samples.mean(0)

        # a more rigorous test would be to do hypothesis testing.
        # due to the statistical nature, this assertion may still fail, though very rarely.
        torch.testing.assert_close(x_dequant_mean, x, atol=1e-4, rtol=1e-4)

    @parametrize("leading_dims", [(), (2,), (2, 4)])
    @parametrize("bias", [False, True])
    @parametrize("device", _DEVICES)
    def test_int8_linear(self, leading_dims, bias, device):
        _reset()
        embed_dim = 32

        linear_fp32 = nn.Linear(embed_dim, embed_dim, bias=bias, device=device)
        linear_int8 = copy.deepcopy(linear_fp32)
        quantize_(linear_int8, int8_weight_only_quantized_training(), set_inductor_config=False)
        linear_fp32.weight.data = linear_int8.weight.data.dequantize()

        input_fp32 = torch.randn(leading_dims + (embed_dim,), device=device)
        input_int8 = input_fp32.clone()
        input_fp32.requires_grad_(True)
        input_int8.requires_grad_(True)

        # test forward
        out_fp32 = linear_fp32(input_fp32)
        out_int8 = linear_int8(input_int8)
        torch.testing.assert_close(out_fp32, out_int8)

        # test backward
        grad = torch.randn(leading_dims + (embed_dim,), device=device)
        out_fp32.backward(grad)
        out_int8.backward(grad)
        torch.testing.assert_close(input_fp32.grad, input_int8.grad)
        torch.testing.assert_close(linear_fp32.weight.grad, linear_int8.weight.grad)
        if bias:
            torch.testing.assert_close(linear_fp32.bias.grad, linear_int8.bias.grad)

    @parametrize("leading_dims", [(), (2,), (2, 4)])
    @parametrize("bias", [False, True])
    @parametrize("device", _DEVICES)
    def test_int8_linear_compile(self, leading_dims, bias, device):
        _reset()
        embed_dim = 128

        linear_eager = nn.Linear(embed_dim, embed_dim, bias=bias, device=device)
        quantize_(linear_eager, int8_weight_only_quantized_training(), set_inductor_config=False)
        linear_compiled = copy.deepcopy(linear_eager)
        linear_compiled.compile()

        input_eager = torch.randn(leading_dims + (embed_dim,), device=device) * 10
        input_compiled = input_eager.clone()
        input_eager.requires_grad_(True)
        input_compiled.requires_grad_(True)

        out_eager = linear_eager(input_eager)
        out_compiled = linear_compiled(input_compiled)
        torch.testing.assert_close(out_eager, out_compiled)

        grad = torch.randn(leading_dims + (embed_dim,), device=device)
        out_eager.backward(grad)
        out_compiled.backward(grad)
        torch.testing.assert_close(input_eager.grad, input_compiled.grad)
        torch.testing.assert_close(linear_eager.weight.grad, linear_compiled.weight.grad)
        if bias:
            torch.testing.assert_close(linear_eager.bias.grad, linear_compiled.bias.grad)

    @parametrize("compile", [False, True])
    @parametrize("device", _DEVICES)
    def test_int8_linear_training(self, compile, device):
        _reset()
        bsize = 4
        embed_dim = 32
        n_classes = 10

        model_fp32 = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 2, bias=False),
            nn.GELU(),
            nn.Linear(embed_dim * 2, n_classes),
        ).to(device)
        model_int8 = copy.deepcopy(model_fp32)
        # don't set inductor flags to speed up CI time
        quantize_(model_int8, int8_weight_only_quantized_training(), set_inductor_config=False)

        if compile:
            model_fp32.compile()
            model_int8.compile()

        optim_fp32 = _AdamW(model_fp32.parameters())
        optim_int8 = _AdamW(model_int8.parameters())

        for _ in range(5):
            inputs = torch.randn(bsize, embed_dim, device=device)
            labels = torch.randint(n_classes, size=(bsize,), device=device)
            loss_fp32 = F.cross_entropy(model_fp32(inputs), labels)
            loss_int8 = F.cross_entropy(model_int8(inputs), labels)

            rel_error = abs(loss_int8.item() - loss_fp32.item()) / abs(loss_fp32.item())
            assert rel_error < 2e-3, rel_error

            loss_fp32.backward()
            optim_fp32.step()
            optim_fp32.zero_grad()

            loss_int8.backward()
            optim_int8.step()
            optim_int8.zero_grad()


class TestFSDP2(FSDPTest):
    @property
    def world_size(self) -> int:
        return 2

    @skip_if_lt_x_gpu(2)
    def test_fsdp2(self):
        # FSDP2 + compiled quantized training fails with PyTorch 2.4
        compile_layer_choices = [False]
        if TORCH_VERSION_AFTER_2_4:
            compile_layer_choices.append(True)

        self.run_subtests(
            {"compile_layer": compile_layer_choices},
            self._test_fsdp2,
        )

    def _test_fsdp2(self, compile_layer):
        import torch.distributed as dist
        from torch.distributed._composable.fsdp import fully_shard
        from torch.testing._internal.distributed._tensor.common_dtensor import ModelArgs, Transformer

        _reset()
        batch_size = 3
        vocab_size = 32
        seq_len = 64
        model_args = ModelArgs(
            n_layers=2,
            n_heads=2,
            dim=128,
            vocab_size=vocab_size,
            max_seq_len=seq_len,
            dropout_p=0,
        )
        torch.manual_seed(42)
        base_model = Transformer(model_args).cuda()
        quantize_(base_model, int8_weight_only_quantized_training(), set_inductor_config=False)
        fsdp_model = copy.deepcopy(base_model)

        if compile_layer:
            for layer in base_model.layers:
                layer.compile()

        for layer in fsdp_model.layers:
            if compile_layer:
                layer.compile()
            fully_shard(layer)
        fully_shard(fsdp_model)

        base_optim = torch.optim.Adam(base_model.parameters(), lr=1e-2, foreach=False, fused=False)
        fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-2, foreach=False, fused=False)

        torch.manual_seed(42 + self.rank + 1)
        for iter_idx in range(5):
            inp = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda")
            fsdp_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
            fsdp_loss = fsdp_model(inp).sum()
            fsdp_loss.backward()
            fsdp_optim.step()

            base_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
            base_loss = base_model(inp).sum()
            base_loss.backward()
            for param in base_model.parameters():
                if param.grad is not None:
                    dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
            base_optim.step()

            # due to stochastic rounding, use a pretty large tolerance here
            rel_error = (fsdp_loss - base_loss).abs() / base_loss.abs()
            assert rel_error < 0.05, rel_error


instantiate_parametrized_tests(TestQuantizedTraining)


if __name__ == "__main__":
    run_tests()

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
 from .adam import Adam8bit, Adam4bit, AdamFp8
-from .adamw import AdamW8bit, AdamW4bit, AdamWFp8
+from .adamw import _AdamW, AdamW8bit, AdamW4bit, AdamWFp8
 from .cpu_offload import CPUOffloadOptimizer
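
With `_AdamW` re-exported here, it can be imported directly from `torchao.prototype.low_bit_optim`, which is how the benchmark script and tests above use it. A tiny illustrative snippet (the linear layer is a placeholder, not from the diff):

from torch import nn

from torchao.prototype.low_bit_optim import _AdamW  # newly exported by this commit

layer = nn.Linear(16, 16)  # placeholder module for illustration
optim = _AdamW(layer.parameters(), lr=3e-4, weight_decay=1e-2)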
