import torch
import torchao
+from torchao.quantization.quant_primitives import (
+    MappingType,
+    ZeroPointDomain,
+)
from .subclass import (  # noqa
    Int8DynamicallyQuantizedLinearWeight,
    Int8WeightOnlyQuantizedLinearWeight,
    QuantizedLinearWeightBase,
)
+from torchao.dtypes import AffineQuantizedTensor, PlainLayoutType
+from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
from torch.utils._python_dispatch import return_and_correct_aliasing
from .quant_primitives import (
    safe_int_mm,
@@ -252,9 +258,9 @@ class AQMixin():
    def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
        w_qtensor = cls.from_float(weight)
        if _is_interpolate_mode(mode):
-            q_c_op = torch.compile(cls._quantized_op, mode="max-autotune-no-cudagraphs")
+            q_c_op = torch.compile(cls._quantized_linear_op, mode="max-autotune-no-cudagraphs")
        else:
-            func = lambda a, b, c: F.relu(cls._quantized_op(F.relu(a), b, c))
+            func = lambda a, b, c: F.relu(cls._quantized_linear_op(F.relu(a), b, c))
            q_c_op = torch.compile(func, mode="max-autotune-no-cudagraphs")
        res = do_autoquant_bench(q_c_op, act_mat, w_qtensor, bias, warmup=25, rep=100)
        if res < best_time * 1.1:
@@ -263,10 +269,48 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
            print(f">>time: {res:0.3f}ms for {cls}, to_beat: {best_time:0.3f}ms")
        return res

-class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, Int8DynamicallyQuantizedLinearWeight):
+class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, LinearActivationQuantizedTensor):
    """
    AutoQuantizable version of Int8DynamicallyQuantizedLinearWeight
    """
+    @classmethod
+    def from_float(cls, weight):
+        # TODO: test if this is valid
+        # in_features = weight.shape[1]
+        # int8 dynamic quantization only has benefit when in_features > 16
+        # if in_features <= 16:
+        #     return weight
+
+        # local import to avoid a circular dependency
+        from torchao.dtypes import to_affine_quantized
+        # weight settings
+        mapping_type = MappingType.SYMMETRIC
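+        # block size (1, in_features): one scale per output row, i.e. per-channel granularity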
+        def get_weight_block_size(x):
+            return (1, x.shape[1])
+        target_dtype = torch.int8
+        eps = torch.finfo(torch.float32).eps
+        zero_point_dtype = torch.int64
+
+        # input settings
+        def get_per_token_block_size(x):
+            block_size = list(x.shape)
+            for i in range(len(block_size) - 1):
+                block_size[i] = 1
+            return block_size
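+        # per-token granularity: an input of shape (B, S, D) gets block size (1, 1, D)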
+
+        input_mapping_type = MappingType.SYMMETRIC
+        input_target_dtype = torch.int8
+        input_eps = 1e-5
+        input_quant_min = -127
+        input_quant_max = 127
+        layout_type = PlainLayoutType()
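+        # activations are quantized on the fly at each linear call; fp16 inputs keep fp32 scales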
+        input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)
+
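+        # quantize the weight eagerly; activation quantization is deferred to input_quant_func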
+        block_size = get_weight_block_size(weight)
+        weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, layout_type=layout_type)
+        weight = super(AQInt8DynamicallyQuantizedLinearWeight, cls).from_float(weight, input_quant_func)
+        return weight
+
    @classmethod
    def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
        """
@@ -298,7 +342,8 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
        )
        q_c_matmul = torch.compile(quantized_matmul, mode="max-autotune-no-cudagraphs")
        with torch.no_grad():
-            res_matmul = do_autoquant_bench(q_c_matmul, x_vals_int8, x_scales.reshape(-1, 1), w_qtensor.int_data)
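+            # the raw int8 weight now lives on the wrapped tensor: original_weight_tensor.layout_tensor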
+            w_vals_int8 = w_qtensor.original_weight_tensor.layout_tensor.int_data.contiguous().t()
+            res_matmul = do_autoquant_bench(q_c_matmul, x_vals_int8, x_scales.reshape(-1, 1), w_vals_int8)
        print(f">>time: {res_matmul:0.3f}ms for {cls} matmul, to_beat: {best_time:0.3f}ms")

        # if the (much faster) matmul kernel is already beat, don't bother benchmarking full op
@@ -313,18 +358,27 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
            print(f">>time: {res_f:0.3f}ms for {cls} interpolated, breakeven constant: {max_int_const_win:0.2f}")
        return res_f

-class AQWeightOnlyQuantizedLinearWeight(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
+class AQWeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin):
    """
    AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight
    """
+    @classmethod
+    def from_float(cls, weight):
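+        # same per-row symmetric int8 scheme as the dynamic variant, minus activation quantization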
+        mapping_type = MappingType.SYMMETRIC
+        target_dtype = torch.int8
+        eps = torch.finfo(torch.float32).eps
+        zero_point_dtype = torch.int64
+        block_size = (1, weight.shape[1])
+        return super(AQWeightOnlyQuantizedLinearWeight, cls).from_float(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
+

-class AQWeightOnlyQuantizedLinearWeight2(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
+class AQWeightOnlyQuantizedLinearWeight2(AQWeightOnlyQuantizedLinearWeight, AQMixin):
    """
    AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight that
    uses a different kernel
    """
    @staticmethod
-    def _quantized_op(act_mat, w_qtensor, bias):
+    def _quantized_linear_op(act_mat, w_qtensor, bias):
        """
        Performs the quantized linear operations
@@ -339,8 +393,8 @@ def _quantized_op(act_mat, w_qtensor, bias):
        orig_dtype = act_mat.dtype
        orig_shape = act_mat.shape
        act_mat = act_mat.reshape(-1, act_mat.shape[-1], 1)
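+        # broadcasted multiply + sum over the in-feature dim computes act_mat @ int_data.t() without torch.mm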
-        y = (act_mat * w_qtensor.int_data.unsqueeze(0)).sum(dim=-2)
-        y = y.reshape(*orig_shape[:-1], y.shape[-1]) * w_qtensor.q_scales
+        y = (act_mat * w_qtensor.layout_tensor.int_data.t().unsqueeze(0)).sum(dim=-2)
+        y = y.reshape(*orig_shape[:-1], y.shape[-1]) * w_qtensor.layout_tensor.scale
        if bias is not None:
            y += bias
        return y.to(orig_dtype)
@@ -352,14 +406,14 @@ def _autoquant_test(cls, act_mat, *args):
            return torch.inf
        return super()._autoquant_test(act_mat, *args)

-class AQWeightOnlyQuantizedLinearWeight3(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
+class AQWeightOnlyQuantizedLinearWeight3(AQWeightOnlyQuantizedLinearWeight, AQMixin):
    """
    AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight that
    uses a different kernel
    """
-    def _quantized_op(act_mat, w_qtensor, bias):
+    def _quantized_linear_op(act_mat, w_qtensor, bias):
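+        # dequantize the whole weight up front (int_data.t() * scale), then run one fp matmul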
        orig_shape = act_mat.shape
-        y = torch.mm(act_mat.reshape(-1, orig_shape[-1]), w_qtensor.int_data * w_qtensor.q_scales)
+        y = torch.mm(act_mat.reshape(-1, orig_shape[-1]), w_qtensor.layout_tensor.int_data.t() * w_qtensor.layout_tensor.scale)
        y = y.reshape(*orig_shape[:-1], y.shape[-1])
        if bias is not None:
            y += bias
@@ -377,7 +431,7 @@ def __init__(self):
        super().__init__()

    @staticmethod
-    def _quantized_op(act_mat, w_qtensor, bias):
+    def _quantized_linear_op(act_mat, w_qtensor, bias):
        return torch.nn.functional.linear(act_mat, w_qtensor, bias)

    @classmethod
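For orientation, a minimal sketch of how the reworked subclasses are exercised. This is an editor's illustration, not part of the diff: the module path is inferred from the imports above, a CUDA fp16 setup is assumed, and only the from_float / _quantized_linear_op entry points visible in the hunks are used.

import torch
from torchao.quantization.autoquant import AQWeightOnlyQuantizedLinearWeight2

lin = torch.nn.Linear(128, 256, device="cuda", dtype=torch.float16)
act = torch.randn(16, 128, device="cuda", dtype=torch.float16)

# from_float wraps the float weight in an int8 AffineQuantizedTensor subclass
qw = AQWeightOnlyQuantizedLinearWeight2.from_float(lin.weight)

# the entry point benchmarked by AQMixin._autoquant_test
out = AQWeightOnlyQuantizedLinearWeight2._quantized_linear_op(act, qw, lin.bias)
print(out.shape)  # torch.Size([16, 256])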