Move QAT out of prototype
Summary: Move QAT out of prototype so we can provide stronger
BC guarantees moving forward.
**(Future) BC-breaking notes**
Note: This commit itself doesn't break BC yet; a future PR
will do that. The following is recorded here to preserve the
BC-breaking note somewhere.
Before:
```
from torchao.quantization.prototype.qat import (
disable_4w_fake_quant,
disable_8da4w_fake_quant,
enable_4w_fake_quant,
enable_8da4w_fake_quant,
ComposableQATQuantizer,
Int4WeightOnlyQATQuantizer,
    Int4WeightOnlyEmbeddingQATQuantizer,
Int8DynActInt4WeightQATQuantizer,
Int8DynActInt4WeightQATLinear,
)
from torchao.quantization.prototype.qat.api import (
FakeQuantizeConfig,
)
from torchao.quantization.prototype.qat.fake_quantizer import (
FakeQuantizer,
)
```
After:
```
from torchao.quantization.qat import (
ComposableQATQuantizer,
Int4WeightOnlyQATQuantizer,
    Int4WeightOnlyEmbeddingQATQuantizer,
Int8DynActInt4WeightQATQuantizer,
)
from torchao.quantization.qat.linear import (
disable_4w_fake_quant,
disable_8da4w_fake_quant,
enable_4w_fake_quant,
enable_8da4w_fake_quant,
Int8DynActInt4WeightQATLinear,
)
from torchao.quantization.qat.api import (
FakeQuantizeConfig,
)
from torchao.quantization.qat.fake_quantizer import (
FakeQuantizer,
)
```
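For downstream callers that need to work both before and after the BC break lands, one possible compatibility shim (an illustrative sketch, not part of this PR) is to try the new path and fall back to the prototype path:
```python
# Illustrative compatibility shim (not part of this PR): prefer the new
# import path, fall back to the prototype path on older torchao versions.
try:
    from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
except ImportError:
    from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
```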
Test Plan:
python test/quantization/test_qat.py
ghstack-source-id: add9dca
Pull Request resolved: #1091
README.md (+1, -1)
@@ -59,7 +59,7 @@

Post-training quantization can result in a fast and compact model, but may also lead to accuracy degradation. We recommend exploring Quantization Aware Training (QAT) to overcome this limitation. In collaboration with Torchtune, we've developed a QAT recipe that demonstrates significant accuracy improvements over traditional PTQ, recovering **96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext** for Llama3 compared to post-training quantization (PTQ). And we've provided a full recipe [here](https://pytorch.org/blog/quantization-aware-training/)

```diff
-from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
```
torchao currently supports two QAT schemes for linear layers (see the sketch after this list):

- int8 per token dynamic activations + int4 per group weights
- int4 per group weights (using the efficient [int4 tinygemm kernel](https://github.com/pytorch/pytorch/blob/a672f6c84e318bbf455f13dfdd3fd7c68a388bf5/aten/src/ATen/native/cuda/int4mm.cu#L1097) after training)
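As a hedged illustration (not part of the README diff), the two schemes map to the quantizer classes exported by this package; the groupsize value below is an assumed example:
```python
from torchao.quantization.qat import (
    Int8DynActInt4WeightQATQuantizer,  # scheme 1: int8 dynamic activations + int4 weights
    Int4WeightOnlyQATQuantizer,        # scheme 2: int4 per group weights only
)

# Illustrative instantiation; groupsize=32 is an assumed example value.
int8da_int4w_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=32)
int4w_quantizer = Int4WeightOnlyQATQuantizer(groupsize=32)
```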
QAT typically involves applying a transformation to your model before and after training.
In torchao, these are represented as the prepare and convert steps: (1) prepare inserts
fake quantize operations into linear layers, and (2) convert transforms the fake quantize
operations to actual quantize and dequantize operations after training, thereby producing
a quantized model (dequantize operations are typically fused with linear after lowering).
Between these two steps, training can proceed exactly as before.
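As a rough illustration of what a fake quantize operation computes (a generic sketch, not torchao's implementation), it rounds values to the integer grid and immediately dequantizes them, so the model trains against quantization error while staying in floating point:
```python
import torch

# Generic per-tensor fake quantize sketch (illustration only, not torchao's code):
# quantize to the integer grid, then immediately dequantize back to float.
def fake_quantize(x: torch.Tensor, scale: float, zero_point: int,
                  qmin: int = -8, qmax: int = 7) -> torch.Tensor:
    q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax)
    return (q - zero_point) * scale
```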
To use QAT in torchao, apply the prepare step using the appropriate Quantizer before
training, then apply the convert step after training for inference or generation.
For example, on a single GPU:
```python
import torch
from torchtune.models.llama3 import llama3
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer

# Smaller version of llama3 to fit in a single GPU
model = llama3(
    vocab_size=4096,
    num_layers=16,
    num_heads=16,
    num_kv_heads=4,
    embed_dim=2048,
    max_seq_len=2048,
).cuda()

# Quantizer for int8 dynamic per token activations +
# int4 grouped per channel weights, only for linear layers
```
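The diff excerpt is cut off at this point; a hedged sketch of the remaining flow, assuming the Quantizer's prepare and convert steps described in the prose above, would instantiate the quantizer, prepare the model, train, and then convert:
```python
# Hedged continuation of the truncated snippet above, assuming the
# Quantizer's prepare/convert steps described in the prose.
qat_quantizer = Int8DynActInt4WeightQATQuantizer()

# Prepare: insert fake quantize operations into linear layers
model = qat_quantizer.prepare(model)

# ... train the model exactly as before ...

# Convert: replace fake quantize ops with real quantize/dequantize ops
model = qat_quantizer.convert(model)
```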
Results for int4 per group weights, using a learning rate of 2e-6. For this quantization scheme, the
quantized path uses the more efficient [int4 tinygemm kernel](https://github.com/pytorch/pytorch/blob/a672f6c84e318bbf455f13dfdd3fd7c68a388bf5/aten/src/ATen/native/cuda/int4mm.cu#L1097).