Commit cf573bd
[reland] Move QAT out of prototype (#1152)
Move QAT out of prototype

Summary: Move QAT out of prototype so we can provide stronger BC guarantees moving forward.

**(Future) BC-breaking notes**

Note: This commit itself doesn't break BC yet. A future PR will do that. The following is just to save this BC-breaking note somewhere.

Before:

```
from torchao.quantization.prototype.qat import (
    disable_4w_fake_quant,
    disable_8da4w_fake_quant,
    enable_4w_fake_quant,
    enable_8da4w_fake_quant,
    ComposableQATQuantizer,
    Int4WeightOnlyQATQuantizer,
    Int4WeightOnlyEmbeddingQATQuantizer,
    Int8DynActInt4WeightQATQuantizer,
    Int8DynActInt4WeightQATLinear,
)
from torchao.quantization.prototype.qat.api import (
    FakeQuantizeConfig,
)
from torchao.quantization.prototype.qat.fake_quantizer import (
    FakeQuantizer,
)
```

After:

```
from torchao.quantization.qat import (
    ComposableQATQuantizer,
    Int4WeightOnlyQATQuantizer,
    Int4WeightOnlyEmbeddingQATQuantizer,
    Int8DynActInt4WeightQATQuantizer,
)
from torchao.quantization.qat.linear import (
    disable_4w_fake_quant,
    disable_8da4w_fake_quant,
    enable_4w_fake_quant,
    enable_8da4w_fake_quant,
    Int8DynActInt4WeightQATLinear,
)
from torchao.quantization.qat.api import (
    FakeQuantizeConfig,
)
from torchao.quantization.qat.fake_quantizer import (
    FakeQuantizer,
)
```

Test Plan: python test/quantization/test_qat.py

ghstack-source-id: add9dca
Pull Request resolved: #1091
1 parent 629aee1 commit cf573bd
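Because the old `torchao.quantization.prototype.qat` paths still work in this commit but are slated to break in a future PR, downstream code that needs to span both layouts can guard the import. A minimal sketch (not part of this commit), using only the paths listed in the note above:

```python
# Prefer the new location; fall back to the prototype path on torchao versions
# that predate this move. Both paths come from the BC note above.
try:
    from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
except ImportError:
    from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer

qat_quantizer = Int8DynActInt4WeightQATQuantizer()
```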

20 files changed: +1703 -1604 lines

README.md

Lines changed: 1 addition & 1 deletion
````diff
@@ -59,7 +59,7 @@ In practice these features alongside int4 weight only quantization allow us to *
 Post-training quantization can result in a fast and compact model, but may also lead to accuracy degradation. We recommend exploring Quantization Aware Training (QAT) to overcome this limitation. In collaboration with Torchtune, we've developed a QAT recipe that demonstrates significant accuracy improvements over traditional PTQ, recovering **96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext** for Llama3 compared to post-training quantization (PTQ). And we've provided a full recipe [here](https://pytorch.org/blog/quantization-aware-training/)
 
 ```python
-from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
 
 qat_quantizer = Int8DynActInt4WeightQATQuantizer()
````

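The README snippet above only shows the updated import and the constructor. For context, here is a minimal sketch of how that quantizer is typically driven, following the prepare/convert flow described in the prototype README removed later in this diff; the toy model is an illustrative assumption, not torchao code:

```python
import torch
from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer

# Toy model for illustration; any model containing nn.Linear layers works the same way.
model = torch.nn.Sequential(
    torch.nn.Linear(256, 256),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 256),
)

qat_quantizer = Int8DynActInt4WeightQATQuantizer()
model = qat_quantizer.prepare(model)  # insert fake-quantize ops into linear layers

# ... fine-tune `model` with the usual training loop ...

model = qat_quantizer.convert(model)  # replace fake quantize with real quantize ops
```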
test/quantization/test_qat.py

Lines changed: 70 additions & 17 deletions
```diff
@@ -22,22 +22,22 @@
     PerRow,
     PerToken,
 )
-from torchao.quantization.prototype.qat.api import (
+from torchao.quantization.qat.api import (
     ComposableQATQuantizer,
     FakeQuantizeConfig,
 )
-from torchao.quantization.prototype.qat.fake_quantizer import (
+from torchao.quantization.qat.fake_quantizer import (
     FakeQuantizer,
 )
-from torchao.quantization.prototype.qat.embedding import (
+from torchao.quantization.qat.embedding import (
     FakeQuantizedEmbedding,
 )
-from torchao.quantization.prototype.qat.linear import (
+from torchao.quantization.qat.linear import (
     FakeQuantizedLinear,
     Int8DynActInt4WeightQATLinear,
     Int4WeightOnlyQATLinear
 )
-from torchao.quantization.prototype.qat.utils import (
+from torchao.quantization.qat.utils import (
     _choose_qparams_per_token_asymmetric,
     _fake_quantize_per_channel_group,
     _fake_quantize_per_token,
@@ -181,7 +181,7 @@ def _set_ptq_weight(
             Int8DynActInt4WeightLinear,
             WeightOnlyInt4Linear,
         )
-        from torchao.quantization.prototype.qat.linear import (
+        from torchao.quantization.qat.linear import (
             Int8DynActInt4WeightQATLinear,
             Int4WeightOnlyQATLinear,
         )
@@ -213,7 +213,7 @@ def _set_ptq_weight(
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_linear(self):
-        from torchao.quantization.prototype.qat.linear import Int8DynActInt4WeightQATLinear
+        from torchao.quantization.qat.linear import Int8DynActInt4WeightQATLinear
         from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
 
         group_size = 128
@@ -238,7 +238,7 @@ def test_qat_8da4w_linear(self):
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_quantizer(self):
-        from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+        from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
         from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer
 
         group_size = 16
@@ -272,7 +272,7 @@ def test_qat_8da4w_quantizer(self):
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_quantizer_meta_weights(self):
-        from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+        from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
 
         with torch.device("meta"):
             m = M()
@@ -287,7 +287,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):
         """
         Test that 8da4w QAT with disabled fake quant matches nn.Linear in forward.
         """
-        from torchao.quantization.prototype.qat import (
+        from torchao.quantization.qat.linear import (
             Int8DynActInt4WeightQATQuantizer,
             disable_8da4w_fake_quant,
             enable_8da4w_fake_quant,
@@ -346,7 +346,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
         """
         Test that 8da4w QAT with disabled fake quant matches nn.Linear in backward.
         """
-        from torchao.quantization.prototype.qat import (
+        from torchao.quantization.qat.linear import (
             Int8DynActInt4WeightQATQuantizer,
             disable_8da4w_fake_quant,
         )
@@ -428,7 +428,7 @@ def _test_qat_quantized_gradients(self, quantizer):
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_quantizer_gradients(self):
-        from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+        from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
         quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=16)
         self._test_qat_quantized_gradients(quantizer)
 
@@ -518,7 +518,7 @@ def test_qat_4w_primitives(self):
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
     def test_qat_4w_linear(self):
-        from torchao.quantization.prototype.qat.linear import Int4WeightOnlyQATLinear
+        from torchao.quantization.qat.linear import Int4WeightOnlyQATLinear
        from torchao.quantization.GPTQ import WeightOnlyInt4Linear
 
         group_size = 128
@@ -545,14 +545,14 @@ def test_qat_4w_linear(self):
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_4w_quantizer_gradients(self):
-        from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
+        from torchao.quantization.qat import Int4WeightOnlyQATQuantizer
         quantizer = Int4WeightOnlyQATQuantizer(groupsize=32, inner_k_tiles=8)
         self._test_qat_quantized_gradients(quantizer)
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
     def test_qat_4w_quantizer(self):
-        from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
+        from torchao.quantization.qat import Int4WeightOnlyQATQuantizer
         from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
 
         group_size = 32
@@ -630,7 +630,7 @@ def test_composable_qat_quantizer(self):
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_4w_embedding(self):
-        from torchao.quantization.prototype.qat import Int4WeightOnlyEmbeddingQATQuantizer
+        from torchao.quantization.qat import Int4WeightOnlyEmbeddingQATQuantizer
         model = M2()
         x = model.example_inputs()
         out = model(*x)
@@ -937,6 +937,59 @@ def embedding_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
         baseline_out = embedding_forward_4w(x2, fq_embedding.weight)
         torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0)
 
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
+    def test_qat_prototype_bc(self):
+        """
+        Just to make sure we can import all the old prototype paths.
+        We will remove this test in the near future when we actually break BC.
+        """
+        from torchao.quantization.prototype.qat import (
+            disable_4w_fake_quant,
+            disable_8da4w_fake_quant,
+            enable_4w_fake_quant,
+            enable_8da4w_fake_quant,
+            ComposableQATQuantizer,
+            Int8DynActInt4WeightQATLinear,
+            Int4WeightOnlyEmbeddingQATQuantizer,
+            Int4WeightOnlyQATQuantizer,
+            Int8DynActInt4WeightQATQuantizer,
+        )
+        from torchao.quantization.prototype.qat._module_swap_api import (
+            disable_4w_fake_quant_module_swap,
+            enable_4w_fake_quant_module_swap,
+            disable_8da4w_fake_quant_module_swap,
+            enable_8da4w_fake_quant_module_swap,
+            Int4WeightOnlyQATQuantizerModuleSwap,
+            Int8DynActInt4WeightQATQuantizerModuleSwap,
+        )
+        from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (
+            AffineFakeQuantizedTensor,
+            to_affine_fake_quantized,
+        )
+        from torchao.quantization.prototype.qat.api import (
+            ComposableQATQuantizer,
+            FakeQuantizeConfig,
+        )
+        from torchao.quantization.prototype.qat.embedding import (
+            FakeQuantizedEmbedding,
+            Int4WeightOnlyEmbeddingQATQuantizer,
+            Int4WeightOnlyEmbedding,
+            Int4WeightOnlyQATEmbedding,
+        )
+        from torchao.quantization.prototype.qat.fake_quantizer import (
+            FakeQuantizer,
+        )
+        from torchao.quantization.prototype.qat.linear import (
+            disable_4w_fake_quant,
+            disable_8da4w_fake_quant,
+            enable_4w_fake_quant,
+            enable_8da4w_fake_quant,
+            FakeQuantizedLinear,
+            Int4WeightOnlyQATLinear,
+            Int4WeightOnlyQATQuantizer,
+            Int8DynActInt4WeightQATLinear,
+            Int8DynActInt4WeightQATQuantizer,
+        )
 
 if __name__ == "__main__":
-    unittest.main()
+    unittest.main()
```
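The new `test_qat_prototype_bc` above only exercises the legacy import paths. A hypothetical complementary check (not part of this PR) is to assert that a legacy module re-exports the very same objects as its new location, for example:

```python
# Illustrative helper, not part of test_qat.py: verify that a legacy alias module
# re-exports identical objects (not copies) from the new location.
import importlib

def check_reexports(old_path: str, new_path: str, names: list[str]) -> None:
    old_mod = importlib.import_module(old_path)
    new_mod = importlib.import_module(new_path)
    for name in names:
        assert getattr(old_mod, name) is getattr(new_mod, name), name

check_reexports(
    "torchao.quantization.prototype.qat",
    "torchao.quantization.qat",
    ["Int8DynActInt4WeightQATQuantizer", "Int4WeightOnlyQATQuantizer"],
)
```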
Lines changed: 3 additions & 125 deletions
````diff
@@ -1,125 +1,3 @@
-# Quantization-Aware Training (QAT)
-
-Quantization-Aware Training (QAT) refers to applying fake quantization during the
-training or fine-tuning process, such that the final quantized model will exhibit
-higher accuracies and perplexities. Fake quantization refers to rounding the float
-values to quantized values without actually casting them to dtypes with lower
-bit-widths, in contrast to post-training quantization (PTQ), which does cast the
-quantized values to lower bit-width dtypes, e.g.:
-
-```
-# PTQ: x_q is quantized and cast to int8
-# scale and zero point (zp) refer to parameters used to quantize x_float
-# qmin and qmax refer to the range of quantized values
-x_q = (x_float / scale + zp).round().clamp(qmin, qmax).cast(int8)
-
-# QAT: x_fq is still in float
-# Fake quantize simulates the numerics of quantize + dequantize
-x_fq = (x_float / scale + zp).round().clamp(qmin, qmax)
-x_fq = (x_fq - zp) * scale
-```
-
-## API
-
-torchao currently supports two QAT schemes for linear layers:
-- int8 per token dynamic activations + int4 per group weights
-- int4 per group weights (using the efficient [int4 tinygemm kernel](https://github.com/pytorch/pytorch/blob/a672f6c84e318bbf455f13dfdd3fd7c68a388bf5/aten/src/ATen/native/cuda/int4mm.cu#L1097) after training)
-
-QAT typically involves applying a transformation to your model before and after training.
-In torchao, these are represented as the prepare and convert steps: (1) prepare inserts
-fake quantize operations into linear layers, and (2) convert transforms the fake quantize
-operations to actual quantize and dequantize operations after training, thereby producing
-a quantized model (dequantize operations are typically fused with linear after lowering).
-Between these two steps, training can proceed exactly as before.
-
-![qat](images/qat_diagram.png)
-
-To use QAT in torchao, apply the prepare step using the appropriate Quantizer before
-training, then apply the convert step after training for inference or generation.
-For example, on a single GPU:
-
-```python
-import torch
-from torchtune.models.llama3 import llama3
-from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
-
-# Smaller version of llama3 to fit in a single GPU
-model = llama3(
-    vocab_size=4096,
-    num_layers=16,
-    num_heads=16,
-    num_kv_heads=4,
-    embed_dim=2048,
-    max_seq_len=2048,
-).cuda()
-
-# Quantizer for int8 dynamic per token activations +
-# int4 grouped per channel weights, only for linear layers
-qat_quantizer = Int8DynActInt4WeightQATQuantizer()
-
-# Insert "fake quantize" operations into linear layers.
-# These operations simulate quantization numerics during
-# training without performing any dtype casting
-model = qat_quantizer.prepare(model)
-
-# Standard training loop
-optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
-loss_fn = torch.nn.CrossEntropyLoss()
-for i in range(10):
-    example = torch.randint(0, 4096, (2, 16)).cuda()
-    target = torch.randn((2, 16, 4096)).cuda()
-    output = model(example)
-    loss = loss_fn(output, target)
-    loss.backward()
-    optimizer.step()
-    optimizer.zero_grad()
-
-# Convert fake quantize to actual quantize operations
-# The quantized model has the exact same structure as the
-# quantized model produced in the corresponding PTQ flow
-# through `Int8DynActInt4WeightQuantizer`
-model = qat_quantizer.convert(model)
-
-# inference or generate
-```
-
-Users can also leverage our integration with [torchtune](https://github.com/pytorch/torchtune)
-and apply quantized-aware fine-tuning as follows:
-
-```
-tune run --nproc_per_node 8 qat_distributed --config llama3/8B_qat_full
-```
-
-For more detail, please refer to [this QAT tutorial](https://pytorch.org/torchtune/main/tutorials/qat_finetune.html).
-
-
-## Evaluation Results
-
-Evaluation was performed on 6-8 A100 GPUs (80GB each) using the torchtune QAT
-integration described above. We fine-tune [Llama3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
-on the [C4 dataset](https://huggingface.co/datasets/allenai/c4) (en subset)
-for 5000 steps using a group size of 256 for the weights. Note that extensive
-hyperparameter tuning may further improve these results.
-
-Results for int8 per token dynamic activations + int4 per group weights, using a learning rate of 2e-5:
-
-| | hellaswag<br>(acc) | hellaswag<br>(acc_norm) | wikitext<br>(word_perplexity) | wikitext<br>(byte_perplexity) | wikitext<br>(bits_per_byte) |
-| ---------------- | ------ | ------ | ------ | ------ | ------ |
-| No quantization | 57.86% | 76.60% | 8.905 | 1.505 | 0.590 |
-| PTQ | 51.74% | 70.66% | 11.878 | 1.588 | 0.668 |
-| QAT (quantized) | 57.25% | 76.51% | 9.859 | 1.534 | 0.617 |
-| PTQ degradation | -6.11% | -5.94% | +2.973 | +0.083 | +0.078 |
-| QAT degradation | -0.61% | -0.21% | +0.947 | +0.029 | +0.027 |
-
-Results for int4 per group weights, using a learning rate of 2e-6. For this quantization scheme, the
-quantized path uses the more efficient [int4 tinygemm kernel](https://github.com/pytorch/pytorch/blob/a672f6c84e318bbf455f13dfdd3fd7c68a388bf5/aten/src/ATen/native/cuda/int4mm.cu#L1097).
-
-| | hellaswag<br>(acc) | hellaswag<br>(acc_norm) | wikitext<br>(word_perplexity) | wikitext<br>(byte_perplexity) | wikitext<br>(bits_per_byte) |
-| ---------------- | -------- | ------- | ------ | ------ | ------ |
-| No quantization | 57.16% | 77.02% | 8.858 | 1.504 | 0.589 |
-| PTQ | 55.06% | 74.24% | 10.311 | 1.547 | 0.630 |
-| QAT (quantized) | 55.86% | 75.06% | 10.134 | 1.542 | 0.625 |
-| PTQ degradation | -2.10% | -2.78% | +1.453 | +0.043 | +0.041 |
-| QAT degradation | -1.30% | -1.96% | +1.276 | +0.038 | +0.036 |
-For more details, please refer to [this blog post](https://pytorch.org/blog/quantization-aware-training).
+Note: QAT has been moved to torchao/quantization/qat.
+This is a legacy folder only for backward compatibility
+and will be removed in the near future.
````
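The removed README above defines fake quantization as round/clamp without a dtype cast, in contrast to PTQ, which does cast. Here is a tiny standalone PyTorch sketch of those numerics; the scale/zero-point choice is illustrative, not torchao's actual quantization algorithm:

```python
import torch

x_float = torch.randn(4, 8)
qmin, qmax = -128, 127

# Illustrative affine parameters covering the observed range of x_float.
scale = (x_float.max() - x_float.min()) / (qmax - qmin)
zp = qmin - torch.round(x_float.min() / scale)

# PTQ: quantize and actually cast to int8
x_q = torch.clamp(torch.round(x_float / scale + zp), qmin, qmax).to(torch.int8)

# QAT-style fake quantize: same round/clamp, but the result stays in float
x_fq = torch.clamp(torch.round(x_float / scale + zp), qmin, qmax)
x_fq = (x_fq - zp) * scale

# The fake-quantized tensor approximates the original up to quantization error.
print((x_float - x_fq).abs().max())
```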

torchao/quantization/prototype/qat/__init__.py

Lines changed: 5 additions & 7 deletions
```diff
@@ -1,17 +1,15 @@
-from .api import (
+from torchao.quantization.qat import (
     ComposableQATQuantizer,
+    Int4WeightOnlyEmbeddingQATQuantizer,
+    Int4WeightOnlyQATQuantizer,
+    Int8DynActInt4WeightQATQuantizer,
 )
-from .linear import (
+from torchao.quantization.qat.linear import (
     disable_4w_fake_quant,
     disable_8da4w_fake_quant,
     enable_4w_fake_quant,
     enable_8da4w_fake_quant,
-    Int4WeightOnlyQATQuantizer,
     Int8DynActInt4WeightQATLinear,
-    Int8DynActInt4WeightQATQuantizer,
-)
-from .embedding import (
-    Int4WeightOnlyEmbeddingQATQuantizer,
 )
 
 __all__ = [
```
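The rewritten `__init__.py` above silently re-exports from the new location. A common follow-on pattern (hypothetical here, not part of this commit) is for the legacy module to also emit a `DeprecationWarning` so users migrate before the path is removed:

```python
# Hypothetical legacy-shim body: warn once on import, then re-export the new API.
import warnings

warnings.warn(
    "torchao.quantization.prototype.qat is deprecated; "
    "import from torchao.quantization.qat instead",
    DeprecationWarning,
    stacklevel=2,
)

from torchao.quantization.qat import (  # noqa: E402
    ComposableQATQuantizer,
    Int4WeightOnlyEmbeddingQATQuantizer,
    Int4WeightOnlyQATQuantizer,
    Int8DynActInt4WeightQATQuantizer,
)
```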

torchao/quantization/prototype/qat/_module_swap_api.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,7 +1,7 @@
 # For backward compatibility only
 # These will be removed in the future
 
-from .linear import (
+from torchao.quantization.qat.linear import (
     Int8DynActInt4WeightQATQuantizer as Int8DynActInt4WeightQATQuantizerModuleSwap,
     Int4WeightOnlyQATQuantizer as Int4WeightOnlyQATQuantizerModuleSwap,
     enable_8da4w_fake_quant as enable_8da4w_fake_quant_module_swap,
```
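Because `_module_swap_api.py` now only aliases the new classes, the old `*ModuleSwap` names should resolve to the very same objects, so existing isinstance and identity checks keep working. A small illustrative check, assuming the shim above:

```python
from torchao.quantization.qat.linear import Int8DynActInt4WeightQATQuantizer
from torchao.quantization.prototype.qat._module_swap_api import (
    Int8DynActInt4WeightQATQuantizerModuleSwap,
)

# The alias is the same class object, not a subclass or a copy.
assert Int8DynActInt4WeightQATQuantizerModuleSwap is Int8DynActInt4WeightQATQuantizer
```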
