@@ -40,7 +40,6 @@
    Int8DynActInt4WeightQATLinear,
)
from torchao.quantization.qat.utils import (
-    _choose_qparams_per_token_asymmetric,
    _fake_quantize_per_channel_group,
    _fake_quantize_per_token,
    _GenericFakeQuantize,
@@ -53,12 +52,16 @@
    MappingType,
    TorchAODType,
    ZeroPointDomain,
+    choose_qparams_affine,
+    dequantize_affine,
    fake_quantize_affine,
+    quantize_affine,
)
from torchao.quantization.unified import (
    TwoStepQuantizer,
)
from torchao.quantization.utils import (
+    _get_per_token_block_size,
    get_group_qparams_symmetric,
    get_groupwise_affine_qparams,
    groupwise_affine_quantize_tensor,
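The new per-token path above goes through torchao's affine quantization primitives. A minimal standalone sketch of how `_get_per_token_block_size` and `choose_qparams_affine` fit together, assuming these names are importable from `torchao.quantization.quant_primitives` and `torchao.quantization.utils` (the enclosing import block is collapsed above, so the exact path is an assumption):

import torch

# Assumed import paths; the collapsed import block above is expected to provide these names.
from torchao.quantization.quant_primitives import MappingType, choose_qparams_affine
from torchao.quantization.utils import _get_per_token_block_size

x = torch.randn(2, 3, 8)

# Per-token grouping: every dimension except the last is split into blocks of 1,
# so each 8-element "token" row gets its own scale / zero_point.
block_size = _get_per_token_block_size(x)  # expected: [1, 1, 8]

scale, zero_point = choose_qparams_affine(
    x,
    mapping_type=MappingType.ASYMMETRIC,
    block_size=block_size,
    target_dtype=torch.int8,
    quant_min=-128,
    quant_max=127,
    scale_dtype=torch.float32,
    zero_point_dtype=torch.int32,
)
# One (scale, zero_point) pair per token, i.e. 2 * 3 = 6 pairs here.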
@@ -134,12 +137,13 @@ def forward(self, x):


class M4(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, dtype: torch.dtype = torch.float32):
        super().__init__()
-        self.linear = torch.nn.Linear(512, 256, bias=False).to(torch.float)
+        self.dtype = dtype
+        self.linear = torch.nn.Linear(512, 256, bias=False).to(dtype)

    def example_inputs(self):
-        return (torch.randn(1, 512).to(torch.float),)
+        return (torch.randn(1, 512).to(self.dtype),)

    def forward(self, x):
        return self.linear(x)
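A quick usage sketch of the dtype-parameterized `M4` above, assuming the class is in scope as defined in this diff: the module, its example inputs, and its outputs now all carry the requested dtype.

import torch

m = M4(torch.bfloat16)      # defaults to torch.float32 when no dtype is given
(x,) = m.example_inputs()   # inputs are created in the module's dtype
assert x.dtype == torch.bfloat16
assert m(x).dtype == torch.bfloat16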
@@ -219,30 +223,41 @@ def test_fake_quantize_per_token(self):
        torch.manual_seed(self.SEED)
        x = torch.randn(100, 256).requires_grad_()
        x2 = copy.deepcopy(x)
-        # TODO: use torch.ops.aten.quantized_decomposed version instead
-        (s, zp) = _choose_qparams_per_token_asymmetric(x, torch.float32, torch.int32)
+        block_size = _get_per_token_block_size(x)
+        (s, zp) = choose_qparams_affine(
+            x,
+            mapping_type=MappingType.ASYMMETRIC,
+            block_size=block_size,
+            target_dtype=torch.int8,
+            quant_min=-128,
+            quant_max=127,
+            scale_dtype=torch.float32,
+            zero_point_dtype=torch.int32,
+        )

        # fake quant op
        out = _fake_quantize_per_token(x, s, zp, qmin, qmax)
        out.sum().backward()

        # compare against PTQ ops
-        out_ptq = torch.ops.quantized_decomposed.quantize_per_token(
+        out_ptq = quantize_affine(
            x2,
+            block_size,
            s,
            zp,
+            torch.int8,
            qmin,
            qmax,
-            torch.int8,
        )
-        out_ptq = torch.ops.quantized_decomposed.dequantize_per_token(
+        out_ptq = dequantize_affine(
            out_ptq,
+            block_size,
            s,
            zp,
+            torch.int8,
            qmin,
            qmax,
-            torch.int8,
-            torch.float32,
+            output_dtype=torch.float32,
        )
        torch.testing.assert_close(out, out_ptq, atol=0, rtol=0)

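For reference, the equivalence the test above exercises can be reproduced standalone: fake quantization is expected to match an explicit quantize → dequantize round trip bit-exactly. This sketch calls `fake_quantize_affine` directly rather than the `_fake_quantize_per_token` helper used in the test, and assumes `fake_quantize_affine` takes the same leading arguments as `quantize_affine` and that all of these names live in `torchao.quantization.quant_primitives`:

import torch
from torchao.quantization.quant_primitives import (  # assumed import path
    MappingType,
    choose_qparams_affine,
    dequantize_affine,
    fake_quantize_affine,
    quantize_affine,
)
from torchao.quantization.utils import _get_per_token_block_size

torch.manual_seed(0)
x = torch.randn(4, 16)
block_size = _get_per_token_block_size(x)
s, zp = choose_qparams_affine(
    x,
    mapping_type=MappingType.ASYMMETRIC,
    block_size=block_size,
    target_dtype=torch.int8,
    quant_min=-128,
    quant_max=127,
    scale_dtype=torch.float32,
    zero_point_dtype=torch.int32,
)

# QAT-style fake quantization in one call...
fq = fake_quantize_affine(x, block_size, s, zp, torch.int8, -128, 127)

# ...versus the explicit PTQ round trip used as the reference above.
q = quantize_affine(x, block_size, s, zp, torch.int8, -128, 127)
dq = dequantize_affine(q, block_size, s, zp, torch.int8, -128, 127, output_dtype=torch.float32)

torch.testing.assert_close(fq, dq, atol=0, rtol=0)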
@@ -1004,8 +1019,15 @@ def linear_forward_8da4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    Baseline for int8 dynamic per token asymmetric + int4 per group symmetric quant.
    """
    # activations
-    (s, zp) = _choose_qparams_per_token_asymmetric(
-        x, torch.float32, torch.int32
+    (s, zp) = choose_qparams_affine(
+        x,
+        mapping_type=MappingType.ASYMMETRIC,
+        block_size=_get_per_token_block_size(x),
+        target_dtype=torch.int8,
+        quant_min=-128,
+        quant_max=127,
+        scale_dtype=torch.float32,
+        zero_point_dtype=torch.int32,
    )
    (qmin, qmax) = _get_qmin_qmax(8)
    x_fq = _fake_quantize_per_token(x, s, zp, qmin, qmax)
@@ -1427,10 +1449,7 @@ def test_qat_linear_bias(self):
        example_inputs = m.example_inputs()
        m(*example_inputs)

-    @unittest.skipIf(
-        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
-    )
-    def test_fake_quantize_per_token_vs_convert(self):
+    def _test_fake_quantize_per_token_vs_convert(self, dtype: torch.dtype):
        """
        Test that the following produce the exact same numerics:
        1. FakeQuantizer with asymmetric per_token config
@@ -1439,7 +1458,7 @@ def test_fake_quantize_per_token_vs_convert(self):
        from torchao.quantization.utils import per_token_dynamic_quant

        torch.manual_seed(self.SEED)
-        x = torch.randn(1, 235, 2048)
+        x = torch.randn(1, 235, 2048).to(dtype)
        config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
        fake_quantizer = FakeQuantizer(config)
        fake_quantizer_out = fake_quantizer(x)
@@ -1449,7 +1468,16 @@ def test_fake_quantize_per_token_vs_convert(self):
    @unittest.skipIf(
        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
    )
-    def test_qat_8da4w_prepare_vs_convert(self):
+    def test_fake_quantize_per_token_vs_convert_fp32(self):
+        self._test_fake_quantize_per_token_vs_convert(torch.float32)
+
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
+    def test_fake_quantize_per_token_vs_convert_bf16(self):
+        self._test_fake_quantize_per_token_vs_convert(torch.bfloat16)
+
+    def _test_qat_8da4w_prepare_vs_convert(self, dtype: torch.dtype):
        """
        Test that the prepare and convert steps of Int8DynActInt4QATQuantizer produces
        numerics that match exactly over N trials.
@@ -1463,7 +1491,7 @@ def test_qat_8da4w_prepare_vs_convert(self):

        for seed in range(self.SEED, self.SEED + num_trials):
            torch.manual_seed(seed)
-            m = M4()
+            m = M4(dtype)
            torch.manual_seed(seed)
            x = m.example_inputs()

@@ -1486,6 +1514,18 @@ def test_qat_8da4w_prepare_vs_convert(self):
        )
        self.assertEqual(len(non_inf_sqnr), 0, fail_message)

+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
+    def test_qat_8da4w_prepare_vs_convert_fp32(self):
+        self._test_qat_8da4w_prepare_vs_convert(torch.float32)
+
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
+    def test_qat_8da4w_prepare_vs_convert_bf16(self):
+        self._test_qat_8da4w_prepare_vs_convert(torch.bfloat16)
+

if __name__ == "__main__":
    unittest.main()