
Commit 934dead

autoquant using aqt (#609)
* autoquant using aqt

Summary:
Changed autoquant to use aqt (AffineQuantizedTensor) instead of the old tensor subclasses.

Changed aqt to first dispatch to a static _quantized_linear_op, which then dispatches to the normal function. This gives autoquant an extension point to modify the kernel functions for the various quantization modes without editing the main kernel function of all the classes. linear_activation_quantized_tensor got the same treatment.

There are some transposes in the aqt kernels that are not present in the subclass kernels, but they do not seem to affect performance (see benchmark_results.txt for an autoquant perf run).

Test Plan:
sh benchmarks.sh
python test_integration.py

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent 1cfe69e commit 934dead
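
The core change is the extension point described in the summary: the F.linear override on the tensor subclass now routes through a static _quantized_linear_op on the weight class, so an autoquant variant can substitute its own kernel without editing the shared dispatch code. Below is a minimal, self-contained sketch of that pattern using plain wrapper classes; the names (QuantizedWeight, AltKernelWeight, linear_dispatch) are hypothetical stand-ins, not the torchao API.

import torch
import torch.nn.functional as F

class QuantizedWeight:
    def __init__(self, int_data, scale):
        self.int_data = int_data      # int8 weight values, shape (out_features, in_features)
        self.scale = scale            # per-row scales, shape (out_features,)

    @staticmethod
    def _quantized_linear_op(input_tensor, weight_tensor, bias):
        # default kernel: dequantize the weight, then run a regular linear
        w = weight_tensor.int_data.to(input_tensor.dtype) * weight_tensor.scale[:, None]
        return F.linear(input_tensor, w, bias)

class AltKernelWeight(QuantizedWeight):
    @staticmethod
    def _quantized_linear_op(input_tensor, weight_tensor, bias):
        # alternative kernel an autoquant class might benchmark instead
        y = input_tensor @ (weight_tensor.int_data.t().to(input_tensor.dtype) * weight_tensor.scale)
        return y if bias is None else y + bias

def linear_dispatch(input_tensor, weight_tensor, bias=None):
    # the single dispatch path; it never hard-codes a kernel
    return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)

x = torch.randn(4, 8)
w = torch.randint(-127, 128, (16, 8), dtype=torch.int8)
s = torch.rand(16) + 0.01
assert torch.allclose(linear_dispatch(x, QuantizedWeight(w, s)),
                      linear_dispatch(x, AltKernelWeight(w, s)), atol=1e-3)

Both weights flow through the same linear_dispatch; only the kernel entry point differs, which is what lets autoquant benchmark several kernels against one dispatch path.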

File tree: 4 files changed (+86, -24 lines)


test/integration/test_integration.py

Lines changed: 6 additions & 6 deletions
@@ -672,6 +672,8 @@ def _test_lin_weight_subclass_impl(
         test_dtype=torch.bfloat16,
         test_shape=(32, 64, 32),
     ):
+        if not "cuda" in test_device:
+            self.skipTest("test requires cuda")
         m, k, n = test_shape
         x = torch.randn(m, k, device=test_device, dtype=test_dtype)
         lin = torch.nn.Linear(k, n, device=test_device).to(test_dtype)
@@ -709,30 +711,28 @@ def test_int8_weight_only_quant_subclass(self, device, dtype):
         )

     @parameterized.expand(COMMON_DEVICE_DTYPE)
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_5, "autoquant+aqt needs newer pytorch")
     def test_aq_int8_dynamic_quant_subclass(self, device, dtype):
         self._test_lin_weight_subclass_impl(
             AQInt8DynamicallyQuantizedLinearWeight.from_float, device, 35, test_dtype=dtype
         )

     @parameterized.expand(COMMON_DEVICE_DTYPE)
-    def test_aq_int8_weight_only_quant_subclass(self, device, dtype):
-        self._test_lin_weight_subclass_impl(
-            AQInt8DynamicallyQuantizedLinearWeight.from_float, device, 35, test_dtype=dtype
-        )
-
-    @parameterized.expand(COMMON_DEVICE_DTYPE)
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_5, "autoquant+aqt needs newer pytorch")
     def test_aq_int8_weight_only_quant_subclass(self, device, dtype):
         self._test_lin_weight_subclass_impl(
             AQWeightOnlyQuantizedLinearWeight.from_float, device, 35, test_dtype=dtype
         )

     @parameterized.expand(COMMON_DEVICE_DTYPE)
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_5, "autoquant+aqt needs newer pytorch")
     def test_aq_int8_weight_only_quant_2_subclass(self, device, dtype):
         self._test_lin_weight_subclass_impl(
             AQWeightOnlyQuantizedLinearWeight2.from_float, device, 35, test_dtype=dtype
         )

     @parameterized.expand(COMMON_DEVICE_DTYPE)
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_5, "autoquant+aqt needs newer pytorch")
     def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype):
         self._test_lin_weight_subclass_impl(
             AQWeightOnlyQuantizedLinearWeight3.from_float, device, 35, test_dtype=dtype

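For reference, here is a hedged sketch of how a device/dtype-parameterized, version-gated test like the ones above can be structured when run standalone; COMMON_DEVICE_DTYPE and TORCH_VERSION_AFTER_2_5 are stand-ins defined locally, not imports from the project.

import unittest
import torch
from parameterized import parameterized

COMMON_DEVICE_DTYPE = [("cpu", torch.float32), ("cuda", torch.bfloat16)]
TORCH_VERSION_AFTER_2_5 = tuple(int(v) for v in torch.__version__.split("+")[0].split(".")[:2]) >= (2, 5)

class ExampleAutoquantTest(unittest.TestCase):
    @parameterized.expand(COMMON_DEVICE_DTYPE)
    @unittest.skipIf(not TORCH_VERSION_AFTER_2_5, "autoquant+aqt needs newer pytorch")
    def test_linear_shape(self, device, dtype):
        # mirror the in-test skip used above for non-cuda devices
        if "cuda" in device and not torch.cuda.is_available():
            self.skipTest("test requires cuda")
        x = torch.randn(4, 8, device=device, dtype=dtype)
        lin = torch.nn.Linear(8, 16, device=device).to(dtype)
        self.assertEqual(lin(x).shape, (4, 16))

if __name__ == "__main__":
    unittest.main()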
torchao/dtypes/affine_quantized_tensor.py

Lines changed: 5 additions & 1 deletion
@@ -155,6 +155,10 @@ def dequantize(self, output_dtype=None):
         int_data, scale, zero_point = self.layout_tensor.get_plain()
         return dequantize_affine(int_data, self.block_size, scale, zero_point, int_data.dtype, self.quant_min, self.quant_max, self.zero_point_domain, output_dtype=output_dtype)

+    @staticmethod
+    def _quantized_linear_op(input_tensor, weight_tensor, bias):
+        return _quantized_linear_op(input_tensor, weight_tensor, bias)
+
     def __tensor_flatten__(self):
         return ["layout_tensor"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype]

@@ -832,7 +836,7 @@ def _(func, types, args, kwargs):
     # is not picked up by any of the dispatch paths in `_quantized_linear_op`, this allows us to
     # make the branches easier to understand in `_quantized_linear_op`
     try:
-        return _quantized_linear_op(input_tensor, weight_tensor, bias)
+        return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
     except:
         if isinstance(input_tensor, AffineQuantizedTensor):
             input_tensor = input_tensor.dequantize()

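The second hunk above keeps the existing try/except structure: the preferred path now calls the weight's _quantized_linear_op, and any failure falls back to dequantizing and running a plain floating-point linear. A rough sketch of that fallback shape, using a stand-in class rather than the real AffineQuantizedTensor:

import torch
import torch.nn.functional as F

class FakeAQTWeight:
    def __init__(self, int_data, scale):
        self.int_data = int_data   # int8, shape (out_features, in_features)
        self.scale = scale         # per-row scales, shape (out_features,)

    def dequantize(self):
        return self.int_data.to(torch.float32) * self.scale[:, None]

    @staticmethod
    def _quantized_linear_op(input_tensor, weight_tensor, bias):
        # pretend no specialized kernel matches this input
        raise NotImplementedError("no specialized dispatch for this case")

def linear_override(input_tensor, weight_tensor, bias=None):
    try:
        # preferred path: the kernel the weight subclass provides
        return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
    except Exception:
        # fallback path: dequantize and use the ordinary floating-point linear
        return F.linear(input_tensor, weight_tensor.dequantize(), bias)

x = torch.randn(2, 8)
w = FakeAQTWeight(torch.randint(-127, 128, (4, 8), dtype=torch.int8), torch.rand(4) + 0.01)
print(linear_override(x, w).shape)  # torch.Size([2, 4])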
torchao/quantization/autoquant.py

Lines changed: 67 additions & 13 deletions
@@ -1,10 +1,16 @@
 import torch
 import torchao
+from torchao.quantization.quant_primitives import (
+    MappingType,
+    ZeroPointDomain,
+)
 from .subclass import ( # noqa
     Int8DynamicallyQuantizedLinearWeight,
     Int8WeightOnlyQuantizedLinearWeight,
     QuantizedLinearWeightBase,
 )
+from torchao.dtypes import AffineQuantizedTensor, PlainLayoutType
+from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
 from torch.utils._python_dispatch import return_and_correct_aliasing
 from .quant_primitives import (
     safe_int_mm,
@@ -252,9 +258,9 @@ class AQMixin():
     def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
         w_qtensor = cls.from_float(weight)
         if _is_interpolate_mode(mode):
-            q_c_op = torch.compile(cls._quantized_op, mode="max-autotune-no-cudagraphs")
+            q_c_op = torch.compile(cls._quantized_linear_op, mode="max-autotune-no-cudagraphs")
         else:
-            func = lambda a,b,c: F.relu(cls._quantized_op(F.relu(a), b, c))
+            func = lambda a,b,c: F.relu(cls._quantized_linear_op(F.relu(a), b, c))
             q_c_op = torch.compile(func, mode="max-autotune-no-cudagraphs")
         res = do_autoquant_bench(q_c_op, act_mat, w_qtensor, bias, warmup=25, rep=100)
         if res < best_time*1.1:
@@ -263,10 +269,48 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
             print(f">>time: {res:0.3f}ms for {cls}, to_beat: {best_time:0.3f}ms ")
         return res

-class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, Int8DynamicallyQuantizedLinearWeight):
+class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, LinearActivationQuantizedTensor):
     """
     AutoQuantizable version of Int8DynamicallyQuantizedLinearWeight
     """
+    @classmethod
+    def from_float(cls, weight):
+        # TODO test if this is valid
+        # in_features = weight.shape[1]
+        # int8 dynamic quantization only has benefit when in_feature > 16
+        # if in_features <= 16:
+        #     return weight
+
+        # avoid circular dep
+        from torchao.dtypes import to_affine_quantized
+        # weight settings
+        mapping_type = MappingType.SYMMETRIC
+        def get_weight_block_size(x):
+            return (1, x.shape[1])
+        target_dtype = torch.int8
+        eps = torch.finfo(torch.float32).eps
+        zero_point_dtype = torch.int64
+
+        # input settings
+        def get_per_token_block_size(x):
+            block_size = list(x.shape)
+            for i in range(len(block_size)-1):
+                block_size[i] = 1
+            return block_size
+
+        input_mapping_type = MappingType.SYMMETRIC
+        input_target_dtype = torch.int8
+        input_eps = 1e-5
+        input_quant_min = -127
+        input_quant_max = 127
+        layout_type = PlainLayoutType()
+        input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)
+
+        block_size = get_weight_block_size(weight)
+        weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, layout_type=layout_type)
+        weight = super(AQInt8DynamicallyQuantizedLinearWeight, cls).from_float(weight, input_quant_func)
+        return weight
+
     @classmethod
     def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
         """
@@ -298,7 +342,8 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
         )
         q_c_matmul=torch.compile(quantized_matmul, mode="max-autotune-no-cudagraphs")
         with torch.no_grad():
-            res_matmul = do_autoquant_bench(q_c_matmul, x_vals_int8, x_scales.reshape(-1,1), w_qtensor.int_data)
+            w_vals_int8 = w_qtensor.original_weight_tensor.layout_tensor.int_data.contiguous().t()
+            res_matmul = do_autoquant_bench(q_c_matmul, x_vals_int8, x_scales.reshape(-1,1), w_vals_int8)
         print(f">>time: {res_matmul:0.3f}ms for {cls} matmul, to_beat: {best_time:0.3f}ms")

         # if the (much faster) matmul kernel is already beat, don't bother benchmarking full op
@@ -313,18 +358,27 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
         print(f">>time: {res_f:0.3f}ms for {cls} interpolated, breakeven constant: {max_int_const_win:0.2f}")
         return res_f

-class AQWeightOnlyQuantizedLinearWeight(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
+class AQWeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin):
     """
     AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight
     """
+    @classmethod
+    def from_float(cls, weight):
+        mapping_type = MappingType.SYMMETRIC
+        target_dtype = torch.int8
+        eps = torch.finfo(torch.float32).eps
+        zero_point_dtype = torch.int64
+        block_size = (1, weight.shape[1])
+        return super(AQWeightOnlyQuantizedLinearWeight, cls).from_float(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
+

-class AQWeightOnlyQuantizedLinearWeight2(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
+class AQWeightOnlyQuantizedLinearWeight2(AQWeightOnlyQuantizedLinearWeight, AQMixin):
     """
     AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight that
     uses a different kernel
     """
     @staticmethod
-    def _quantized_op(act_mat, w_qtensor, bias):
+    def _quantized_linear_op(act_mat, w_qtensor, bias):
         """
         Performs the quantized linear operations

@@ -339,8 +393,8 @@ def _quantized_op(act_mat, w_qtensor, bias):
         orig_dtype = act_mat.dtype
         orig_shape = act_mat.shape
         act_mat = act_mat.reshape(-1, act_mat.shape[-1], 1)
-        y = (act_mat*w_qtensor.int_data.unsqueeze(0)).sum(dim=-2)
-        y = y.reshape(*orig_shape[:-1], y.shape[-1]) * w_qtensor.q_scales
+        y = (act_mat*w_qtensor.layout_tensor.int_data.t().unsqueeze(0)).sum(dim=-2)
+        y = y.reshape(*orig_shape[:-1], y.shape[-1]) * w_qtensor.layout_tensor.scale
         if bias is not None:
             y += bias
         return y.to(orig_dtype)
@@ -352,14 +406,14 @@ def _autoquant_test(cls, act_mat, *args):
             return torch.inf
         return super()._autoquant_test(act_mat, *args)

-class AQWeightOnlyQuantizedLinearWeight3(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
+class AQWeightOnlyQuantizedLinearWeight3(AQWeightOnlyQuantizedLinearWeight, AQMixin):
     """
     AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight that
     uses a different kernel
     """
-    def _quantized_op(act_mat, w_qtensor, bias):
+    def _quantized_linear_op(act_mat, w_qtensor, bias):
         orig_shape = act_mat.shape
-        y = torch.mm(act_mat.reshape(-1, orig_shape[-1]), w_qtensor.int_data*w_qtensor.q_scales)
+        y = torch.mm(act_mat.reshape(-1, orig_shape[-1]), w_qtensor.layout_tensor.int_data.t()*w_qtensor.layout_tensor.scale)
         y=y.reshape(*orig_shape[:-1], y.shape[-1])
         if bias is not None:
             y += bias
@@ -377,7 +431,7 @@ def __init__(self):
         super().__init__()

     @staticmethod
-    def _quantized_op(act_mat, w_qtensor, bias):
+    def _quantized_linear_op(act_mat, w_qtensor, bias):
         return torch.nn.functional.linear(act_mat, w_qtensor, bias)

     @classmethod

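The new from_float methods above choose a weight block size of (1, in_features), i.e. one scale per output row, and a per-token block size for the activation where every dimension except the last is collapsed to 1. The toy example below, with a hypothetical toy_symmetric_int8 helper standing in for to_affine_quantized, shows what those block sizes mean in practice.

import torch

def get_weight_block_size(w):
    return (1, w.shape[1])                       # one scale per output row

def get_per_token_block_size(x):
    block_size = list(x.shape)
    for i in range(len(block_size) - 1):
        block_size[i] = 1                        # one scale per token; only the last dim stays whole
    return block_size

def toy_symmetric_int8(t, reduce_dim):
    # hypothetical stand-in for to_affine_quantized: symmetric int8 over one block dimension
    scale = t.abs().amax(dim=reduce_dim, keepdim=True).clamp(min=1e-6) / 127.0
    q = torch.clamp(torch.round(t / scale), -127, 127).to(torch.int8)
    return q, scale

w = torch.randn(16, 8)        # (out_features, in_features)
x = torch.randn(2, 4, 8)      # (batch, seq, in_features)
print(get_weight_block_size(w))       # (1, 8)   -> per-row weight scales
print(get_per_token_block_size(x))    # [1, 1, 8] -> per-token activation scales
wq, w_scale = toy_symmetric_int8(w, reduce_dim=1)    # shapes: (16, 8), (16, 1)
xq, x_scale = toy_symmetric_int8(x, reduce_dim=-1)   # shapes: (2, 4, 8), (2, 4, 1)
print(wq.shape, w_scale.shape, xq.shape, x_scale.shape)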
torchao/quantization/linear_activation_quantized_tensor.py

Lines changed: 8 additions & 4 deletions
@@ -56,6 +56,13 @@ def __tensor_unflatten__(
             input_quant_func,
         )

+    @staticmethod
+    def _quantized_linear_op(input_tensor, weight_tensor, bias):
+        input_quant_func = weight_tensor.input_quant_func
+        original_weight_tensor = weight_tensor.original_weight_tensor
+        aqt = input_quant_func(input_tensor)
+        return torch.nn.functional.linear(aqt, original_weight_tensor, bias)
+
     @classmethod
     def from_float(cls, input_float, input_quant_func):
         return cls(input_float, input_quant_func)
@@ -101,10 +108,7 @@ def _(func, types, args, kwargs):
         args[2] if len(args) > 2 else None,
     )
     if isinstance(weight_tensor, LinearActivationQuantizedTensor):
-        input_quant_func = weight_tensor.input_quant_func
-        original_weight_tensor = weight_tensor.original_weight_tensor
-        aqt = input_quant_func(input_tensor)
-        return torch.nn.functional.linear(aqt, original_weight_tensor, bias)
+        return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)

     raise NotImplementedError("LinearActivationQuantizedTensor: No specialized dispatch found for linear op")


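Conceptually, the new LinearActivationQuantizedTensor._quantized_linear_op just applies the stored input_quant_func to the incoming activation and then hands the pair to the regular linear. A small sketch with plain functions (illustrative names, not the real subclass):

import torch
import torch.nn.functional as F

def fake_input_quant_func(x):
    # stand-in for to_affine_quantized on the activation: per-token fake-quantize to int8 and back
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / 127.0
    return torch.clamp(torch.round(x / scale), -127, 127) * scale

def quantized_linear_op(input_tensor, weight_tensor, bias, input_quant_func):
    aqt = input_quant_func(input_tensor)        # quantize the activation first
    return F.linear(aqt, weight_tensor, bias)   # then hand off to the usual linear

x = torch.randn(4, 8)
w = torch.randn(16, 8)
print(quantized_linear_op(x, w, None, fake_input_quant_func).shape)  # torch.Size([4, 16])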