22
22
PerRow ,
23
23
PerToken ,
24
24
)
25
- from torchao .quantization . prototype .qat .api import (
25
+ from torchao .prototype . quantization .qat .api import (
26
26
ComposableQATQuantizer ,
27
27
FakeQuantizeConfig ,
28
28
)
29
- from torchao .quantization . prototype .qat .fake_quantizer import (
29
+ from torchao .prototype . quantization .qat .fake_quantizer import (
30
30
FakeQuantizer ,
31
31
)
32
- from torchao .quantization . prototype .qat .linear import (
32
+ from torchao .prototype . quantization .qat .linear import (
33
33
FakeQuantizedLinear ,
34
34
)
35
- from torchao .quantization . prototype .qat .utils import (
35
+ from torchao .prototype . quantization .qat .utils import (
36
36
_choose_qparams_per_token_asymmetric ,
37
37
_fake_quantize_per_channel_group ,
38
38
_fake_quantize_per_token ,
@@ -172,7 +172,7 @@ def _set_ptq_weight(
172
172
Int8DynActInt4WeightLinear ,
173
173
WeightOnlyInt4Linear ,
174
174
)
175
- from torchao .quantization . prototype .qat .linear import (
175
+ from torchao .prototype . quantization .qat .linear import (
176
176
Int8DynActInt4WeightQATLinear ,
177
177
Int4WeightOnlyQATLinear ,
178
178
)
@@ -204,7 +204,7 @@ def _set_ptq_weight(
204
204
205
205
@unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_4 , "skipping when torch version is 2.4 or lower" )
206
206
def test_qat_8da4w_linear (self ):
207
- from torchao .quantization . prototype .qat .linear import Int8DynActInt4WeightQATLinear
207
+ from torchao .prototype . quantization .qat .linear import Int8DynActInt4WeightQATLinear
208
208
from torchao .quantization .GPTQ import Int8DynActInt4WeightLinear
209
209
210
210
group_size = 128
@@ -229,7 +229,7 @@ def test_qat_8da4w_linear(self):
229
229
230
230
@unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_4 , "skipping when torch version is 2.4 or lower" )
231
231
def test_qat_8da4w_quantizer (self ):
232
- from torchao .quantization . prototype .qat import Int8DynActInt4WeightQATQuantizer
232
+ from torchao .prototype . quantization .qat import Int8DynActInt4WeightQATQuantizer
233
233
from torchao .quantization .GPTQ import Int8DynActInt4WeightQuantizer
234
234
235
235
group_size = 16
@@ -263,7 +263,7 @@ def test_qat_8da4w_quantizer(self):
263
263
264
264
@unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_4 , "skipping when torch version is 2.4 or lower" )
265
265
def test_qat_8da4w_quantizer_meta_weights (self ):
266
- from torchao .quantization . prototype .qat import Int8DynActInt4WeightQATQuantizer
266
+ from torchao .prototype . quantization .qat import Int8DynActInt4WeightQATQuantizer
267
267
268
268
with torch .device ("meta" ):
269
269
m = M ()
@@ -278,7 +278,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):
278
278
"""
279
279
Test that 8da4w QAT with disabled fake quant matches nn.Linear in forward.
280
280
"""
281
- from torchao .quantization . prototype .qat import (
281
+ from torchao .prototype . quantization .qat import (
282
282
Int8DynActInt4WeightQATQuantizer ,
283
283
disable_8da4w_fake_quant ,
284
284
enable_8da4w_fake_quant ,
@@ -337,7 +337,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
337
337
"""
338
338
Test that 8da4w QAT with disabled fake quant matches nn.Linear in backward.
339
339
"""
340
- from torchao .quantization . prototype .qat import (
340
+ from torchao .prototype . quantization .qat import (
341
341
Int8DynActInt4WeightQATQuantizer ,
342
342
disable_8da4w_fake_quant ,
343
343
)
@@ -419,7 +419,7 @@ def _test_qat_quantized_gradients(self, quantizer):
419
419
420
420
@unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_4 , "skipping when torch version is 2.4 or lower" )
421
421
def test_qat_8da4w_quantizer_gradients (self ):
422
- from torchao .quantization . prototype .qat import Int8DynActInt4WeightQATQuantizer
422
+ from torchao .prototype . quantization .qat import Int8DynActInt4WeightQATQuantizer
423
423
quantizer = Int8DynActInt4WeightQATQuantizer (groupsize = 16 )
424
424
self ._test_qat_quantized_gradients (quantizer )
425
425
@@ -509,7 +509,7 @@ def test_qat_4w_primitives(self):
509
509
@unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_4 , "skipping when torch version is 2.4 or lower" )
510
510
@unittest .skipIf (not _CUDA_IS_AVAILABLE , "skipping when cuda is not available" )
511
511
def test_qat_4w_linear (self ):
512
- from torchao .quantization . prototype .qat .linear import Int4WeightOnlyQATLinear
512
+ from torchao .prototype . quantization .qat .linear import Int4WeightOnlyQATLinear
513
513
from torchao .quantization .GPTQ import WeightOnlyInt4Linear
514
514
515
515
group_size = 128
@@ -536,14 +536,14 @@ def test_qat_4w_linear(self):
536
536
537
537
@unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_4 , "skipping when torch version is 2.4 or lower" )
538
538
def test_qat_4w_quantizer_gradients (self ):
539
- from torchao .quantization . prototype .qat import Int4WeightOnlyQATQuantizer
539
+ from torchao .prototype . quantization .qat import Int4WeightOnlyQATQuantizer
540
540
quantizer = Int4WeightOnlyQATQuantizer (groupsize = 32 , inner_k_tiles = 8 )
541
541
self ._test_qat_quantized_gradients (quantizer )
542
542
543
543
@unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_4 , "skipping when torch version is 2.4 or lower" )
544
544
@unittest .skipIf (not _CUDA_IS_AVAILABLE , "skipping when cuda is not available" )
545
545
def test_qat_4w_quantizer (self ):
546
- from torchao .quantization . prototype .qat import Int4WeightOnlyQATQuantizer
546
+ from torchao .prototype . quantization .qat import Int4WeightOnlyQATQuantizer
547
547
from torchao .quantization .GPTQ import Int4WeightOnlyQuantizer
548
548
549
549
group_size = 32
@@ -621,7 +621,7 @@ def test_composable_qat_quantizer(self):
621
621
622
622
@unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_4 , "skipping when torch version is 2.4 or lower" )
623
623
def test_qat_4w_embedding (self ):
624
- from torchao .quantization . prototype .qat import Int4WeightOnlyEmbeddingQATQuantizer
624
+ from torchao .prototype . quantization .qat import Int4WeightOnlyEmbeddingQATQuantizer
625
625
model = M2 ()
626
626
x = model .example_inputs ()
627
627
out = model (* x )
0 commit comments