Commit 85c7332

move files from quantization/prototype -> prototype/quantization

1 parent 7a35695 · commit 85c7332
30 files changed, +21 -21 lines
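
The rename is purely mechanical: everything that lived under `torchao.quantization.prototype` now lives under `torchao.prototype.quantization`, and the 21 changed lines below only update import paths. A minimal sketch of how downstream code might bridge the two layouts while the rename propagates (the try/except shim is illustrative only; it is not part of this commit):

```python
# Hypothetical compatibility shim, not part of this commit: prefer the new
# torchao.prototype.quantization layout and fall back to the old one.
try:
    from torchao.prototype.quantization.qat import Int8DynActInt4WeightQATQuantizer
except ImportError:
    from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer

qat_quantizer = Int8DynActInt4WeightQATQuantizer()
```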

README.md (+1, -1)

@@ -59,7 +59,7 @@ In practice these features alongside int4 weight only quantization allow us to *
 Post-training quantization can result in a fast and compact model, but may also lead to accuracy degradation. We recommend exploring Quantization Aware Training (QAT) to overcome this limitation. In collaboration with Torchtune, we've developed a QAT recipe that demonstrates significant accuracy improvements over traditional PTQ, recovering **96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext** for Llama3 compared to post-training quantization (PTQ). And we've provided a full recipe [here](https://pytorch.org/blog/quantization-aware-training/)
 
 ```python
-from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+from torchao.prototype.quantization.qat import Int8DynActInt4WeightQATQuantizer
 
 qat_quantizer = Int8DynActInt4WeightQATQuantizer()
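
The README snippet above only shows the quantizer import and construction. For context, a hedged sketch of the surrounding two-step QAT flow under the new path, assuming the quantizer keeps the usual `prepare`/`convert` API; the toy model and the elided training loop are placeholders:

```python
import torch.nn as nn
from torchao.prototype.quantization.qat import Int8DynActInt4WeightQATQuantizer

# Toy stand-in for a real model; QAT targets the nn.Linear layers.
model = nn.Sequential(nn.Linear(256, 512), nn.ReLU(), nn.Linear(512, 256))

qat_quantizer = Int8DynActInt4WeightQATQuantizer()

# Step 1: insert fake-quantize ops so training sees int8/int4 rounding error.
model = qat_quantizer.prepare(model)

# ... run the usual fine-tuning loop on `model` here ...

# Step 2: replace fake quantization with real quantized ops for inference.
model = qat_quantizer.convert(model)
```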

test/quantization/test_mixed_precision.py renamed to test/prototype/test_mixed_precision.py (+1, -1)

@@ -4,7 +4,7 @@
 import torch.nn as nn
 from torchao.quantization import quantize_, int8_weight_only, int4_weight_only
 from torchao.quantization.utils import compute_error
-from torchao.quantization.prototype.mixed_precision.scripts.naive_intNwo import intN_weight_only
+from torchao.prototype.quantization.mixed_precision.scripts.naive_intNwo import intN_weight_only
 
 _CUDA_IS_AVAILABLE = torch.cuda.is_available()
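
The renamed test exercises the prototype `intN_weight_only` config together with `compute_error`. A minimal sketch of that measurement pattern using the stable `int8_weight_only` config instead (signatures assumed from current torchao; `compute_error` reports the SQNR in dB between two tensors):

```python
import copy
import torch
import torch.nn as nn
from torchao.quantization import quantize_, int8_weight_only
from torchao.quantization.utils import compute_error

# Full-precision reference and a deep copy to quantize in place.
model = nn.Sequential(nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 128)).eval()
quantized = copy.deepcopy(model)
quantize_(quantized, int8_weight_only())

x = torch.randn(8, 128)
with torch.no_grad():
    sqnr = compute_error(model(x), quantized(x))
print("SQNR between fp32 and int8 weight-only outputs (dB):", float(sqnr))
```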

test/quantization/test_qat.py renamed to test/prototype/test_qat.py (+15, -15)

@@ -22,17 +22,17 @@
     PerRow,
     PerToken,
 )
-from torchao.quantization.prototype.qat.api import (
+from torchao.prototype.quantization.qat.api import (
     ComposableQATQuantizer,
     FakeQuantizeConfig,
 )
-from torchao.quantization.prototype.qat.fake_quantizer import (
+from torchao.prototype.quantization.qat.fake_quantizer import (
     FakeQuantizer,
 )
-from torchao.quantization.prototype.qat.linear import (
+from torchao.prototype.quantization.qat.linear import (
     FakeQuantizedLinear,
 )
-from torchao.quantization.prototype.qat.utils import (
+from torchao.prototype.quantization.qat.utils import (
     _choose_qparams_per_token_asymmetric,
     _fake_quantize_per_channel_group,
     _fake_quantize_per_token,
@@ -172,7 +172,7 @@ def _set_ptq_weight(
     Int8DynActInt4WeightLinear,
     WeightOnlyInt4Linear,
 )
-from torchao.quantization.prototype.qat.linear import (
+from torchao.prototype.quantization.qat.linear import (
     Int8DynActInt4WeightQATLinear,
     Int4WeightOnlyQATLinear,
 )
@@ -204,7 +204,7 @@ def _set_ptq_weight(
 
 @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
 def test_qat_8da4w_linear(self):
-    from torchao.quantization.prototype.qat.linear import Int8DynActInt4WeightQATLinear
+    from torchao.prototype.quantization.qat.linear import Int8DynActInt4WeightQATLinear
     from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
 
     group_size = 128
@@ -229,7 +229,7 @@ def test_qat_8da4w_linear(self):
 
 @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
 def test_qat_8da4w_quantizer(self):
-    from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+    from torchao.prototype.quantization.qat import Int8DynActInt4WeightQATQuantizer
     from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer
 
     group_size = 16
@@ -263,7 +263,7 @@ def test_qat_8da4w_quantizer(self):
 
 @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
 def test_qat_8da4w_quantizer_meta_weights(self):
-    from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+    from torchao.prototype.quantization.qat import Int8DynActInt4WeightQATQuantizer
 
     with torch.device("meta"):
         m = M()
@@ -278,7 +278,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):
     """
     Test that 8da4w QAT with disabled fake quant matches nn.Linear in forward.
     """
-    from torchao.quantization.prototype.qat import (
+    from torchao.prototype.quantization.qat import (
         Int8DynActInt4WeightQATQuantizer,
         disable_8da4w_fake_quant,
         enable_8da4w_fake_quant,
@@ -337,7 +337,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
     """
     Test that 8da4w QAT with disabled fake quant matches nn.Linear in backward.
     """
-    from torchao.quantization.prototype.qat import (
+    from torchao.prototype.quantization.qat import (
         Int8DynActInt4WeightQATQuantizer,
         disable_8da4w_fake_quant,
     )
@@ -419,7 +419,7 @@ def _test_qat_quantized_gradients(self, quantizer):
 
 @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
 def test_qat_8da4w_quantizer_gradients(self):
-    from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+    from torchao.prototype.quantization.qat import Int8DynActInt4WeightQATQuantizer
     quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=16)
     self._test_qat_quantized_gradients(quantizer)
 
@@ -509,7 +509,7 @@ def test_qat_4w_primitives(self):
 @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
 @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
 def test_qat_4w_linear(self):
-    from torchao.quantization.prototype.qat.linear import Int4WeightOnlyQATLinear
+    from torchao.prototype.quantization.qat.linear import Int4WeightOnlyQATLinear
     from torchao.quantization.GPTQ import WeightOnlyInt4Linear
 
     group_size = 128
@@ -536,14 +536,14 @@ def test_qat_4w_linear(self):
 
 @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
 def test_qat_4w_quantizer_gradients(self):
-    from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
+    from torchao.prototype.quantization.qat import Int4WeightOnlyQATQuantizer
     quantizer = Int4WeightOnlyQATQuantizer(groupsize=32, inner_k_tiles=8)
     self._test_qat_quantized_gradients(quantizer)
 
 @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
 @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
 def test_qat_4w_quantizer(self):
-    from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
+    from torchao.prototype.quantization.qat import Int4WeightOnlyQATQuantizer
     from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
 
     group_size = 32
@@ -621,7 +621,7 @@ def test_composable_qat_quantizer(self):
 
 @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
 def test_qat_4w_embedding(self):
-    from torchao.quantization.prototype.qat import Int4WeightOnlyEmbeddingQATQuantizer
+    from torchao.prototype.quantization.qat import Int4WeightOnlyEmbeddingQATQuantizer
     model = M2()
     x = model.example_inputs()
     out = model(*x)
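
Several of the tests above toggle fake quantization on a prepared model and compare it against plain `nn.Linear`. A hedged sketch of that pattern under the new paths, assuming `disable_8da4w_fake_quant` / `enable_8da4w_fake_quant` are applied per-module via `Module.apply`, as in these tests:

```python
import torch.nn as nn
from torchao.prototype.quantization.qat import (
    Int8DynActInt4WeightQATQuantizer,
    disable_8da4w_fake_quant,
    enable_8da4w_fake_quant,
)

model = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 64))
quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=16)
model = quantizer.prepare(model)

# With fake quant disabled, the prepared model should behave like plain
# nn.Linear, which is what the forward/backward tests above verify.
model.apply(disable_8da4w_fake_quant)

# Re-enable it to resume QAT numerics during training.
model.apply(enable_8da4w_fake_quant)
```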

torchao/quantization/prototype/qat/README.md renamed to torchao/prototype/quantization/qat/README.md (+1, -1)

@@ -41,7 +41,7 @@ For example, on a single GPU:
 ```python
 import torch
 from torchtune.models.llama3 import llama3
-from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+from torchao.prototype.quantization.qat import Int8DynActInt4WeightQATQuantizer
 
 # Smaller version of llama3 to fit in a single GPU
 model = llama3(

torchao/quantization/prototype/qat/utils.py renamed to torchao/prototype/quantization/qat/utils.py (+2, -2)

@@ -46,7 +46,7 @@ def forward(
     zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
 ) -> torch.Tensor:
     # avoid circular dependencies
-    from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (
+    from torchao.prototype.quantization.qat.affine_fake_quantized_tensor import (
         AffineFakeQuantizedTensor,
     )
 
@@ -88,7 +88,7 @@ def forward(
     input: torch.Tensor,
 ) -> torch.Tensor:
     # avoid circular dependencies
-    from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (
+    from torchao.prototype.quantization.qat.affine_fake_quantized_tensor import (
         AffineFakeQuantizedTensor,
     )
     assert isinstance(input, AffineFakeQuantizedTensor)
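
Both hunks keep the `AffineFakeQuantizedTensor` import inside `forward`, as the `# avoid circular dependencies` comment notes: a module-level import would run while the two modules are still initializing each other. A generic sketch of that deferred-import pattern, with hypothetical module names (not torchao's):

```python
# widgets.py -- hypothetical module that gadgets.py also imports.
def build():
    # Deferred import: at module scope this would complete the cycle
    # widgets -> gadgets -> widgets before either module finished loading.
    # Inside the function it only runs at call time, when both are ready.
    from gadgets import Gadget
    return Gadget()
```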

torchao/quantization/quant_api.py (+1, -1)

@@ -220,7 +220,7 @@ def _replace_with_custom_fn_if_matches_filter(
 
 def _is_linear(mod, *args):
     # avoid circular dependencies
-    from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (
+    from torchao.prototype.quantization.qat.affine_fake_quantized_tensor import (
         AffineFakeQuantizedTensor,
     )
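
`_is_linear` is the module filter `quantize_` uses to decide which submodules to swap; the lazy import lets it also recognize QAT's fake-quantized weights without a module-level cycle. A hedged sketch of supplying a custom filter instead, assuming `quantize_` accepts a `filter_fn(module, fqn)` callable as in current torchao:

```python
import torch.nn as nn
from torchao.quantization import quantize_, int8_weight_only

model = nn.Sequential(nn.Linear(128, 128), nn.Linear(128, 10))

def only_first_linear(module, fqn):
    # fqn is the module's fully qualified name within `model`.
    return isinstance(module, nn.Linear) and fqn == "0"

# Quantize only the first linear layer; the default filter would take both.
quantize_(model, int8_weight_only(), filter_fn=only_first_linear)
```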
