Commit 35dbae7

Move QAT out of prototype
Summary: Move QAT out of prototype so we can provide stronger BC guarantees moving forward.

**BC-breaking notes**

Before:
```python
from torchao.quantization.prototype.qat import (
    disable_4w_fake_quant,
    disable_8da4w_fake_quant,
    enable_4w_fake_quant,
    enable_8da4w_fake_quant,
    ComposableQATQuantizer,
    Int4WeightOnlyQATQuantizer,
    Int4WeightOnlyEmbeddingQATQuantizer,
    Int8DynActInt4WeightQATQuantizer,
    Int8DynActInt4WeightQATLinear,
)
from torchao.quantization.prototype.qat.api import (
    FakeQuantizeConfig,
)
from torchao.quantization.prototype.qat.fake_quantizer import (
    FakeQuantizer,
)
```

After:
```python
from torchao.quantization.qat import (
    ComposableQATQuantizer,
    Int4WeightOnlyQATQuantizer,
    Int4WeightOnlyEmbeddingQATQuantizer,
    Int8DynActInt4WeightQATQuantizer,
)
from torchao.quantization.qat.linear import (
    disable_4w_fake_quant,
    disable_8da4w_fake_quant,
    enable_4w_fake_quant,
    enable_8da4w_fake_quant,
    Int8DynActInt4WeightQATLinear,
)
from torchao.quantization.qat.api import (
    FakeQuantizeConfig,
)
from torchao.quantization.qat.fake_quantizer import (
    FakeQuantizer,
)
```

Test Plan: python test/quantization/test_qat.py

ghstack-source-id: cb72a8b
Pull Request resolved: #1091
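Downstream code that needs to run against both old and new torchao releases can guard the import. A minimal sketch; this shim is illustrative and not part of the commit:

```python
try:
    # torchao after this commit: QAT lives under the stable namespace
    from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
except ImportError:
    # older torchao: QAT is still under the prototype namespace
    from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
```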
1 parent 7aaf0ff commit 35dbae7

File tree

13 files changed: +21 −42 lines changed


README.md

Lines changed: 1 addition & 1 deletion

````diff
@@ -59,7 +59,7 @@ In practice these features alongside int4 weight only quantization allow us to *
 Post-training quantization can result in a fast and compact model, but may also lead to accuracy degradation. We recommend exploring Quantization Aware Training (QAT) to overcome this limitation. In collaboration with Torchtune, we've developed a QAT recipe that demonstrates significant accuracy improvements over traditional PTQ, recovering **96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext** for Llama3 compared to post-training quantization (PTQ). And we've provided a full recipe [here](https://pytorch.org/blog/quantization-aware-training/)
 
 ```python
-from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
 
 qat_quantizer = Int8DynActInt4WeightQATQuantizer()
 
````
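The README snippet above is the start of torchao's prepare → train → convert QAT flow. A minimal end-to-end sketch under the new import path; the toy model, optimizer, and dummy data are ours, not from the README:

```python
import torch
from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer

model = torch.nn.Sequential(torch.nn.Linear(512, 512))
qat_quantizer = Int8DynActInt4WeightQATQuantizer()

# Insert fake-quantize ops; the model still trains in floating point.
model = qat_quantizer.prepare(model)

# Ordinary training loop on the prepared model (one dummy step here).
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
loss = model(torch.randn(8, 512)).sum()
loss.backward()
optimizer.step()

# Swap in actually-quantized int8-activation / int4-weight modules.
model = qat_quantizer.convert(model)
```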

test/quantization/test_qat.py

Lines changed: 16 additions & 16 deletions

```diff
@@ -22,22 +22,22 @@
     PerRow,
     PerToken,
 )
-from torchao.quantization.prototype.qat.api import (
+from torchao.quantization.qat.api import (
     ComposableQATQuantizer,
     FakeQuantizeConfig,
 )
-from torchao.quantization.prototype.qat.fake_quantizer import (
+from torchao.quantization.qat.fake_quantizer import (
     FakeQuantizer,
 )
-from torchao.quantization.prototype.qat.embedding import (
+from torchao.quantization.qat.embedding import (
     FakeQuantizedEmbedding,
 )
-from torchao.quantization.prototype.qat.linear import (
+from torchao.quantization.qat.linear import (
     FakeQuantizedLinear,
     Int8DynActInt4WeightQATLinear,
     Int4WeightOnlyQATLinear
 )
-from torchao.quantization.prototype.qat.utils import (
+from torchao.quantization.qat.utils import (
     _choose_qparams_per_token_asymmetric,
     _fake_quantize_per_channel_group,
     _fake_quantize_per_token,
@@ -181,7 +181,7 @@ def _set_ptq_weight(
             Int8DynActInt4WeightLinear,
             WeightOnlyInt4Linear,
         )
-        from torchao.quantization.prototype.qat.linear import (
+        from torchao.quantization.qat.linear import (
             Int8DynActInt4WeightQATLinear,
             Int4WeightOnlyQATLinear,
         )
@@ -213,7 +213,7 @@ def _set_ptq_weight(
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_linear(self):
-        from torchao.quantization.prototype.qat.linear import Int8DynActInt4WeightQATLinear
+        from torchao.quantization.qat.linear import Int8DynActInt4WeightQATLinear
         from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
 
         group_size = 128
@@ -238,7 +238,7 @@ def test_qat_8da4w_linear(self):
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_quantizer(self):
-        from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+        from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
         from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer
 
         group_size = 16
@@ -272,7 +272,7 @@ def test_qat_8da4w_quantizer(self):
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_quantizer_meta_weights(self):
-        from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+        from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
 
         with torch.device("meta"):
             m = M()
@@ -287,7 +287,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):
         """
         Test that 8da4w QAT with disabled fake quant matches nn.Linear in forward.
         """
-        from torchao.quantization.prototype.qat import (
+        from torchao.quantization.qat.linear import (
             Int8DynActInt4WeightQATQuantizer,
             disable_8da4w_fake_quant,
             enable_8da4w_fake_quant,
@@ -346,7 +346,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
         """
         Test that 8da4w QAT with disabled fake quant matches nn.Linear in backward.
         """
-        from torchao.quantization.prototype.qat import (
+        from torchao.quantization.qat.linear import (
             Int8DynActInt4WeightQATQuantizer,
             disable_8da4w_fake_quant,
         )
@@ -428,7 +428,7 @@ def _test_qat_quantized_gradients(self, quantizer):
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_quantizer_gradients(self):
-        from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+        from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
         quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=16)
         self._test_qat_quantized_gradients(quantizer)
 
@@ -518,7 +518,7 @@ def test_qat_4w_primitives(self):
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
     def test_qat_4w_linear(self):
-        from torchao.quantization.prototype.qat.linear import Int4WeightOnlyQATLinear
+        from torchao.quantization.qat.linear import Int4WeightOnlyQATLinear
         from torchao.quantization.GPTQ import WeightOnlyInt4Linear
 
         group_size = 128
@@ -545,14 +545,14 @@ def test_qat_4w_linear(self):
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
    def test_qat_4w_quantizer_gradients(self):
-        from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
+        from torchao.quantization.qat import Int4WeightOnlyQATQuantizer
         quantizer = Int4WeightOnlyQATQuantizer(groupsize=32, inner_k_tiles=8)
         self._test_qat_quantized_gradients(quantizer)
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
     def test_qat_4w_quantizer(self):
-        from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
+        from torchao.quantization.qat import Int4WeightOnlyQATQuantizer
         from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
 
         group_size = 32
@@ -630,7 +630,7 @@ def test_composable_qat_quantizer(self):
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_4w_embedding(self):
-        from torchao.quantization.prototype.qat import Int4WeightOnlyEmbeddingQATQuantizer
+        from torchao.quantization.qat import Int4WeightOnlyEmbeddingQATQuantizer
         model = M2()
         x = model.example_inputs()
         out = model(*x)
```
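The tests above exercise FakeQuantizeConfig and FakeQuantizer from their new homes. A minimal sketch of that pairing; the exact constructor arguments are our assumption based on these tests, not shown in this diff:

```python
import torch
from torchao.quantization.qat.api import FakeQuantizeConfig
from torchao.quantization.qat.fake_quantizer import FakeQuantizer

# Int4 per-group fake quantization of a weight tensor (group_size is assumed).
config = FakeQuantizeConfig(torch.int4, group_size=32)
fake_quantizer = FakeQuantizer(config)

weight = torch.randn(64, 128)
# Same shape and dtype as the input, with values snapped to the int4 grid.
fake_quantized_weight = fake_quantizer(weight)
```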

torchao/quantization/prototype/qat/_module_swap_api.py

Lines changed: 0 additions & 11 deletions
This file was deleted.

torchao/quantization/prototype/qat/README.md renamed to torchao/quantization/qat/README.md

Lines changed: 1 addition & 1 deletion

````diff
@@ -41,7 +41,7 @@ For example, on a single GPU:
 ```python
 import torch
 from torchtune.models.llama3 import llama3
-from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
 
 # Smaller version of llama3 to fit in a single GPU
 model = llama3(
````

torchao/quantization/prototype/qat/__init__.py renamed to torchao/quantization/qat/__init__.py

Lines changed: 0 additions & 10 deletions

```diff
@@ -2,26 +2,16 @@
     ComposableQATQuantizer,
 )
 from .linear import (
-    disable_4w_fake_quant,
-    disable_8da4w_fake_quant,
-    enable_4w_fake_quant,
-    enable_8da4w_fake_quant,
     Int4WeightOnlyQATQuantizer,
-    Int8DynActInt4WeightQATLinear,
     Int8DynActInt4WeightQATQuantizer,
 )
 from .embedding import (
     Int4WeightOnlyEmbeddingQATQuantizer,
 )
 
 __all__ = [
-    "disable_4w_fake_quant",
-    "disable_8da4w_fake_quant",
-    "enable_4w_fake_quant",
-    "enable_8da4w_fake_quant",
     "ComposableQATQuantizer",
     "Int4WeightOnlyQATQuantizer",
     "Int4WeightOnlyEmbeddingQATQuantizer"
     "Int8DynActInt4WeightQATQuantizer",
-    "Int8DynActInt4WeightQATLinear",
 ]
```
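Since the package root no longer re-exports the fake-quant toggles, callers now pull them from the linear submodule, as the BC-breaking notes describe. A minimal sketch of the usual toggle pattern; the toy model and groupsize are ours:

```python
import torch
from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
from torchao.quantization.qat.linear import (
    disable_8da4w_fake_quant,
    enable_8da4w_fake_quant,
)

model = torch.nn.Sequential(torch.nn.Linear(256, 256))
model = Int8DynActInt4WeightQATQuantizer(groupsize=128).prepare(model)

# Temporarily train without fake quantization (e.g. to stabilize early steps),
# then re-enable it for the QAT phase.
model.apply(disable_8da4w_fake_quant)
# ... train a few steps ...
model.apply(enable_8da4w_fake_quant)
```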

torchao/quantization/prototype/qat/utils.py renamed to torchao/quantization/qat/utils.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -46,7 +46,7 @@ def forward(
         zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
     ) -> torch.Tensor:
         # avoid circular dependencies
-        from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (
+        from torchao.quantization.qat.affine_fake_quantized_tensor import (
             AffineFakeQuantizedTensor,
         )
 
@@ -88,7 +88,7 @@ def forward(
         input: torch.Tensor,
     ) -> torch.Tensor:
         # avoid circular dependencies
-        from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (
+        from torchao.quantization.qat.affine_fake_quantized_tensor import (
             AffineFakeQuantizedTensor,
         )
         assert isinstance(input, AffineFakeQuantizedTensor)
```

torchao/quantization/quant_api.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -220,7 +220,7 @@ def _replace_with_custom_fn_if_matches_filter(
 
 def _is_linear(mod, *args):
     # avoid circular dependencies
-    from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (
+    from torchao.quantization.qat.affine_fake_quantized_tensor import (
         AffineFakeQuantizedTensor,
     )
 
```
