 import torch
 import torchao
+from torchao.quantization.quant_primitives import (
+    MappingType,
+    ZeroPointDomain,
+)
 from .subclass import ( # noqa
     Int8DynamicallyQuantizedLinearWeight,
     Int8WeightOnlyQuantizedLinearWeight,
     QuantizedLinearWeightBase,
 )
+from torchao.dtypes import AffineQuantizedTensor, PlainLayoutType
+from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
 from torch.utils._python_dispatch import return_and_correct_aliasing
 from .quant_primitives import (
     safe_int_mm,
@@ -252,9 +258,9 @@ class AQMixin():
     def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
         w_qtensor = cls.from_float(weight)
         if _is_interpolate_mode(mode):
-            q_c_op = torch.compile(cls._quantized_op, mode="max-autotune-no-cudagraphs")
+            q_c_op = torch.compile(cls._quantized_linear_op, mode="max-autotune-no-cudagraphs")
         else:
-            func = lambda a, b, c: F.relu(cls._quantized_op(F.relu(a), b, c))
+            func = lambda a, b, c: F.relu(cls._quantized_linear_op(F.relu(a), b, c))
             q_c_op = torch.compile(func, mode="max-autotune-no-cudagraphs")
         res = do_autoquant_bench(q_c_op, act_mat, w_qtensor, bias, warmup=25, rep=100)
         if res < best_time*1.1:
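When `mode` is the interpolate mode, the test above compiles the bare op; otherwise it times the op sandwiched between `F.relu` calls so `torch.compile` can fuse the surrounding elementwise ops the way it would inside a real network. A minimal self-contained sketch of that benchmarking pattern (the CUDA-event timing here is an assumption; the actual `do_autoquant_bench` is defined elsewhere in this file):

import torch
import torch.nn.functional as F

def bench_relu_sandwich(op, act, w, bias, warmup=25, rep=100):
    # compile relu(op(relu(x))) so the pre/post elementwise ops can fuse with op
    func = torch.compile(lambda a, b, c: F.relu(op(F.relu(a), b, c)),
                         mode="max-autotune-no-cudagraphs")
    for _ in range(warmup):
        func(act, w, bias)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(rep):
        func(act, w, bias)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / rep  # milliseconds per call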
@@ -263,10 +269,53 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
         print(f">>time: {res:0.3f}ms for {cls}, to_beat: {best_time:0.3f}ms ")
         return res

-class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, Int8DynamicallyQuantizedLinearWeight):
+###### TODO !!!!!!!!!!!!!!!
+# 1) make class method from_float (just duplicate code)
+# 2) undo changes to quant_api?
+# 3) point to new quantized_op location
+# 4) rewrite the dynamic autoquant test
+
+class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, LinearActivationQuantizedTensor):
     """
     AutoQuantizable version of Int8DynamicallyQuantizedLinearWeight
     """
+    @classmethod
+    def from_float(cls, weight):
+        in_features = weight.shape[1]
+        # int8 dynamic quantization only has benefit when in_features > 16
+        # if in_features <= 16:
+        #     return weight
+
+        # avoid circular dep
+        from torchao.dtypes import to_affine_quantized
+        # weight settings
+        mapping_type = MappingType.SYMMETRIC
+        def get_weight_block_size(x):
+            return (1, x.shape[1])
+        target_dtype = torch.int8
+        eps = torch.finfo(torch.float32).eps
+        zero_point_dtype = torch.int64
+
+        # input settings
+        def get_per_token_block_size(x):
+            block_size = list(x.shape)
+            for i in range(len(block_size) - 1):
+                block_size[i] = 1
+            return block_size
+
+        input_mapping_type = MappingType.SYMMETRIC
+        input_target_dtype = torch.int8
+        input_eps = 1e-5
+        input_quant_min = -127
+        input_quant_max = 127
+        layout_type = PlainLayoutType()
+        input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)
+
+        block_size = get_weight_block_size(weight)
+        weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, layout_type=layout_type)
+        weight = super(AQInt8DynamicallyQuantizedLinearWeight, cls).from_float(weight, input_quant_func)
+        return weight
+
     @classmethod
     def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
         """
@@ -298,12 +347,13 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
         )
         q_c_matmul = torch.compile(quantized_matmul, mode="max-autotune-no-cudagraphs")
         with torch.no_grad():
-            res_matmul = do_autoquant_bench(q_c_matmul, x_vals_int8, x_scales.reshape(-1, 1), w_qtensor.int_data)
+            w_vals_int8 = w_qtensor.original_weight_tensor.layout_tensor.int_data.contiguous().t()
+            res_matmul = do_autoquant_bench(q_c_matmul, x_vals_int8, x_scales.reshape(-1, 1), w_vals_int8)
         print(f">>time: {res_matmul:0.3f}ms for {cls} matmul, to_beat: {best_time:0.3f}ms")

         # if the (much faster) matmul kernel is already beat, don't bother benchmarking full op
-        if res_matmul >= best_time:
-            return res_matmul
+        # if res_matmul >= best_time:
+        #     return res_matmul

         # calculate what time full op needs to beat for dynamic quant to be best given INTERPOLATION_CONSTANT
         to_beat = best_time + INTERPOLATION_CONSTANT/(1-INTERPOLATION_CONSTANT)*(best_time-res_matmul)
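With the early return commented out above, the test always interpolates between the bare matmul time and the full op time. Plugging made-up numbers into the `to_beat` formula shows how much headroom the full op gets when its matmul is fast (the constant's value of 0.85 is assumed from its definition elsewhere in this file):

INTERPOLATION_CONSTANT = 0.85       # assumed value; defined elsewhere in this file
best_time, res_matmul = 1.00, 0.60  # ms, made-up benchmark results
to_beat = best_time + INTERPOLATION_CONSTANT/(1-INTERPOLATION_CONSTANT)*(best_time-res_matmul)
print(f"{to_beat:0.3f}")            # 3.267: the full op may run well over best_time and still win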
@@ -313,18 +363,27 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
         print(f">>time: {res_f:0.3f}ms for {cls} interpolated, breakeven constant: {max_int_const_win:0.2f}")
         return res_f

-class AQWeightOnlyQuantizedLinearWeight(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
+class AQWeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin):
     """
     AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight
     """
+    @classmethod
+    def from_float(cls, weight):
+        mapping_type = MappingType.SYMMETRIC
+        target_dtype = torch.int8
+        eps = torch.finfo(torch.float32).eps
+        zero_point_dtype = torch.int64
+        block_size = (1, weight.shape[1])
+        return super(AQWeightOnlyQuantizedLinearWeight, cls).from_float(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
+

-class AQWeightOnlyQuantizedLinearWeight2(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
+class AQWeightOnlyQuantizedLinearWeight2(AQWeightOnlyQuantizedLinearWeight, AQMixin):
     """
     AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight that
     uses a different kernel
     """
     @staticmethod
-    def _quantized_op(act_mat, w_qtensor, bias):
+    def _quantized_linear_op(act_mat, w_qtensor, bias):
         """
         Performs the quantized linear operations

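For intuition, a plain-PyTorch reference (not the torchao implementation, which routes through `AffineQuantizedTensor.from_float`) of what symmetric int8 quantization with `block_size = (1, in_features)` produces: one scale per output row and an implicit zero point of 0.

import torch

w = torch.randn(1024, 512)                       # (out_features, in_features)
scale = w.abs().amax(dim=1, keepdim=True) / 127  # one scale per output row
w_int8 = torch.clamp(torch.round(w / scale), -128, 127).to(torch.int8)
w_hat = w_int8.to(w.dtype) * scale               # dequantized approximation of w
print((w - w_hat).abs().max())                   # error bounded by scale/2 per row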
@@ -339,8 +398,8 @@ def _quantized_op(act_mat, w_qtensor, bias):
         orig_dtype = act_mat.dtype
         orig_shape = act_mat.shape
         act_mat = act_mat.reshape(-1, act_mat.shape[-1], 1)
-        y = (act_mat * w_qtensor.int_data.unsqueeze(0)).sum(dim=-2)
-        y = y.reshape(*orig_shape[:-1], y.shape[-1]) * w_qtensor.q_scales
+        y = (act_mat * w_qtensor.layout_tensor.int_data.t().unsqueeze(0)).sum(dim=-2)
+        y = y.reshape(*orig_shape[:-1], y.shape[-1]) * w_qtensor.layout_tensor.scale
         if bias is not None:
             y += bias
         return y.to(orig_dtype)
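The kernel above contracts the shared in_features dimension with a broadcast multiply-and-sum instead of a matmul, while `AQWeightOnlyQuantizedLinearWeight3` below dequantizes once and issues a single `torch.mm`. A small equivalence check with made-up sizes (`int_data` is stored as `(out_features, in_features)`, hence the `.t()`):

import torch

B, K, N = 8, 64, 32                                         # batch, in_features, out_features
act = torch.randn(B, K)
w_int = torch.randint(-128, 128, (N, K), dtype=torch.int8)  # quantized weight, (out, in)
scale = torch.rand(N)                                       # per-output-channel scales

# broadcast kernel: (B, K, 1) * (1, K, N) -> (B, K, N), then sum over K
y1 = (act.reshape(-1, K, 1) * w_int.t().unsqueeze(0)).sum(dim=-2) * scale
# mm kernel: dequantize once, then a single matmul
y2 = torch.mm(act, w_int.t() * scale)
assert torch.allclose(y1, y2, atol=1e-3)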
@@ -352,14 +411,14 @@ def _autoquant_test(cls, act_mat, *args):
             return torch.inf
         return super()._autoquant_test(act_mat, *args)

-class AQWeightOnlyQuantizedLinearWeight3(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
+class AQWeightOnlyQuantizedLinearWeight3(AQWeightOnlyQuantizedLinearWeight, AQMixin):
     """
     AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight that
     uses a different kernel
     """
-    def _quantized_op(act_mat, w_qtensor, bias):
+    def _quantized_linear_op(act_mat, w_qtensor, bias):
         orig_shape = act_mat.shape
-        y = torch.mm(act_mat.reshape(-1, orig_shape[-1]), w_qtensor.int_data * w_qtensor.q_scales)
+        y = torch.mm(act_mat.reshape(-1, orig_shape[-1]), w_qtensor.layout_tensor.int_data.t() * w_qtensor.layout_tensor.scale)
         y = y.reshape(*orig_shape[:-1], y.shape[-1])
         if bias is not None:
             y += bias
@@ -377,7 +436,7 @@ def __init__(self):
         super().__init__()

     @staticmethod
-    def _quantized_op(act_mat, w_qtensor, bias):
+    def _quantized_linear_op(act_mat, w_qtensor, bias):
         return torch.nn.functional.linear(act_mat, w_qtensor, bias)

     @classmethod
@@ -389,7 +448,7 @@ def from_float(cls, weight):
     AQWeightOnlyQuantizedLinearWeight,
     AQWeightOnlyQuantizedLinearWeight2,
     # AQWeightOnlyQuantizedLinearWeight3,
-    # TODO this gets picked in places where it makes perf worse, why?
+    # # TODO this gets picked in places where it makes perf worse, why?
     AQInt8DynamicallyQuantizedLinearWeight,
 ]
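This list is what the autoquantizer searches over for each linear layer. A hedged usage sketch (`torchao.autoquant` is the entry point in current torchao; the exact spelling at this commit may differ):

import torch
import torchao

model = torch.nn.Sequential(torch.nn.Linear(512, 1024)).cuda().half()
model = torchao.autoquant(model)  # wrap Linear weights for autoquantization
x = torch.randn(8, 512, device="cuda", dtype=torch.half)
model(x)  # first call benchmarks the candidate classes above and keeps the fastest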