     PerRow,
     PerToken,
 )
-from torchao.quantization.prototype.qat.api import (
+from torchao.quantization.qat.api import (
     ComposableQATQuantizer,
     FakeQuantizeConfig,
 )
-from torchao.quantization.prototype.qat.fake_quantizer import (
+from torchao.quantization.qat.fake_quantizer import (
     FakeQuantizer,
 )
-from torchao.quantization.prototype.qat.embedding import (
+from torchao.quantization.qat.embedding import (
     FakeQuantizedEmbedding,
 )
-from torchao.quantization.prototype.qat.linear import (
+from torchao.quantization.qat.linear import (
     FakeQuantizedLinear,
     Int8DynActInt4WeightQATLinear,
     Int4WeightOnlyQATLinear
 )
-from torchao.quantization.prototype.qat.utils import (
+from torchao.quantization.qat.utils import (
     _choose_qparams_per_token_asymmetric,
     _fake_quantize_per_channel_group,
     _fake_quantize_per_token,
@@ -181,7 +181,7 @@ def _set_ptq_weight(
             Int8DynActInt4WeightLinear,
             WeightOnlyInt4Linear,
         )
-        from torchao.quantization.prototype.qat.linear import (
+        from torchao.quantization.qat.linear import (
             Int8DynActInt4WeightQATLinear,
             Int4WeightOnlyQATLinear,
         )
@@ -213,7 +213,7 @@ def _set_ptq_weight(
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_linear(self):
-        from torchao.quantization.prototype.qat.linear import Int8DynActInt4WeightQATLinear
+        from torchao.quantization.qat.linear import Int8DynActInt4WeightQATLinear
         from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
 
         group_size = 128
@@ -238,7 +238,7 @@ def test_qat_8da4w_linear(self):
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_quantizer(self):
-        from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+        from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
         from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer
 
         group_size = 16
@@ -272,7 +272,7 @@ def test_qat_8da4w_quantizer(self):
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_quantizer_meta_weights(self):
-        from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+        from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
 
         with torch.device("meta"):
             m = M()
@@ -287,7 +287,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):
         """
         Test that 8da4w QAT with disabled fake quant matches nn.Linear in forward.
         """
-        from torchao.quantization.prototype.qat import (
+        from torchao.quantization.qat.linear import (
             Int8DynActInt4WeightQATQuantizer,
             disable_8da4w_fake_quant,
             enable_8da4w_fake_quant,
@@ -346,7 +346,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
         """
         Test that 8da4w QAT with disabled fake quant matches nn.Linear in backward.
         """
-        from torchao.quantization.prototype.qat import (
+        from torchao.quantization.qat.linear import (
             Int8DynActInt4WeightQATQuantizer,
             disable_8da4w_fake_quant,
         )
@@ -428,7 +428,7 @@ def _test_qat_quantized_gradients(self, quantizer):
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_quantizer_gradients(self):
-        from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+        from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
         quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=16)
         self._test_qat_quantized_gradients(quantizer)
 
@@ -518,7 +518,7 @@ def test_qat_4w_primitives(self):
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
     def test_qat_4w_linear(self):
-        from torchao.quantization.prototype.qat.linear import Int4WeightOnlyQATLinear
+        from torchao.quantization.qat.linear import Int4WeightOnlyQATLinear
         from torchao.quantization.GPTQ import WeightOnlyInt4Linear
 
         group_size = 128
@@ -545,14 +545,14 @@ def test_qat_4w_linear(self):
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_4w_quantizer_gradients(self):
-        from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
+        from torchao.quantization.qat import Int4WeightOnlyQATQuantizer
         quantizer = Int4WeightOnlyQATQuantizer(groupsize=32, inner_k_tiles=8)
         self._test_qat_quantized_gradients(quantizer)
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
     def test_qat_4w_quantizer(self):
-        from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
+        from torchao.quantization.qat import Int4WeightOnlyQATQuantizer
         from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
 
         group_size = 32
@@ -630,7 +630,7 @@ def test_composable_qat_quantizer(self):
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_4w_embedding(self):
-        from torchao.quantization.prototype.qat import Int4WeightOnlyEmbeddingQATQuantizer
+        from torchao.quantization.qat import Int4WeightOnlyEmbeddingQATQuantizer
         model = M2()
         x = model.example_inputs()
         out = model(*x)
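Every hunk above is the same mechanical change: imports of torchao.quantization.prototype.qat move to torchao.quantization.qat with the symbol names unchanged (only the disable/enable_8da4w_fake_quant helpers additionally move into the .linear submodule). Downstream code that must run against both old and new torchao releases during this migration could bridge the two paths with an import fallback; a minimal sketch, not part of this commit:

    # Prefer the new import path introduced by this change; fall back to
    # the old prototype path on torchao releases that predate the move.
    # Illustrative sketch only -- this fallback is not in the diff above.
    try:
        from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
    except ImportError:
        from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer

    # Usage is identical under either path, matching the tests above.
    quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=16)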