autoquant using aqt #609


Merged
6 commits merged on Aug 8, 2024
12 changes: 6 additions & 6 deletions test/integration/test_integration.py
@@ -672,6 +672,8 @@ def _test_lin_weight_subclass_impl(
test_dtype=torch.bfloat16,
test_shape=(32, 64, 32),
):
if not "cuda" in test_device:
self.skipTest("test requires cuda")
m, k, n = test_shape
x = torch.randn(m, k, device=test_device, dtype=test_dtype)
lin = torch.nn.Linear(k, n, device=test_device).to(test_dtype)
@@ -709,30 +711,28 @@ def test_int8_weight_only_quant_subclass(self, device, dtype):
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AFTER_2_5, "autoquant+aqt needs newer pytorch")
def test_aq_int8_dynamic_quant_subclass(self, device, dtype):
self._test_lin_weight_subclass_impl(
AQInt8DynamicallyQuantizedLinearWeight.from_float, device, 35, test_dtype=dtype
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
def test_aq_int8_weight_only_quant_subclass(self, device, dtype):
self._test_lin_weight_subclass_impl(
AQInt8DynamicallyQuantizedLinearWeight.from_float, device, 35, test_dtype=dtype
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AFTER_2_5, "autoquant+aqt needs newer pytorch")
def test_aq_int8_weight_only_quant_subclass(self, device, dtype):
self._test_lin_weight_subclass_impl(
AQWeightOnlyQuantizedLinearWeight.from_float, device, 35, test_dtype=dtype
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AFTER_2_5, "autoquant+aqt needs newer pytorch")
def test_aq_int8_weight_only_quant_2_subclass(self, device, dtype):
self._test_lin_weight_subclass_impl(
AQWeightOnlyQuantizedLinearWeight2.from_float, device, 35, test_dtype=dtype
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AFTER_2_5, "autoquant+aqt needs newer pytorch")
def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype):
self._test_lin_weight_subclass_impl(
AQWeightOnlyQuantizedLinearWeight3.from_float, device, 35, test_dtype=dtype
6 changes: 5 additions & 1 deletion torchao/dtypes/affine_quantized_tensor.py
@@ -155,6 +155,10 @@ def dequantize(self, output_dtype=None):
int_data, scale, zero_point = self.layout_tensor.get_plain()
return dequantize_affine(int_data, self.block_size, scale, zero_point, int_data.dtype, self.quant_min, self.quant_max, self.zero_point_domain, output_dtype=output_dtype)

@staticmethod
def _quantized_linear_op(input_tensor, weight_tensor, bias):
return _quantized_linear_op(input_tensor, weight_tensor, bias)

def __tensor_flatten__(self):
return ["layout_tensor"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype]

@@ -832,7 +836,7 @@ def _(func, types, args, kwargs):
# is not picked up by any of the dispatch paths in `_quantized_linear_op`, this allows us to
# make the branches easier to understand in `_quantized_linear_op`
try:
return _quantized_linear_op(input_tensor, weight_tensor, bias)
return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
except:
if isinstance(input_tensor, AffineQuantizedTensor):
input_tensor = input_tensor.dequantize()
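The change above routes the linear dispatch through weight_tensor._quantized_linear_op rather than the module-level _quantized_linear_op, so a tensor subclass can swap in its own kernel without touching the dispatch code. Below is a minimal, self-contained sketch of that pattern; the classes are hypothetical stand-ins, not torchao's actual tensor subclasses.

import torch
import torch.nn.functional as F

class BaseQuantWeight:
    # Stand-in for AffineQuantizedTensor; just wraps a dense weight.
    def __init__(self, weight):
        self.weight = weight

    @staticmethod
    def _quantized_linear_op(input_tensor, weight_tensor, bias):
        # Default kernel.
        return F.linear(input_tensor, weight_tensor.weight, bias)

class CustomKernelWeight(BaseQuantWeight):
    @staticmethod
    def _quantized_linear_op(input_tensor, weight_tensor, bias):
        # A subclass substitutes its own kernel; the dispatch below is unchanged.
        y = input_tensor @ weight_tensor.weight.t()
        return y if bias is None else y + bias

def linear_dispatch(input_tensor, weight_tensor, bias=None):
    # Mirrors the changed line: route through the weight's own staticmethod.
    return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)

x, w = torch.randn(2, 8), torch.randn(4, 8)
print(torch.allclose(linear_dispatch(x, BaseQuantWeight(w)),
                     linear_dispatch(x, CustomKernelWeight(w))))  # True: same math, different kernels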
80 changes: 67 additions & 13 deletions torchao/quantization/autoquant.py
@@ -1,10 +1,16 @@
import torch
import torchao
from torchao.quantization.quant_primitives import (
MappingType,
ZeroPointDomain,
)
from .subclass import ( # noqa
Int8DynamicallyQuantizedLinearWeight,
Int8WeightOnlyQuantizedLinearWeight,
QuantizedLinearWeightBase,
)
from torchao.dtypes import AffineQuantizedTensor, PlainLayoutType
from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
from torch.utils._python_dispatch import return_and_correct_aliasing
from .quant_primitives import (
safe_int_mm,
@@ -252,9 +258,9 @@ class AQMixin():
def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
w_qtensor = cls.from_float(weight)
if _is_interpolate_mode(mode):
q_c_op = torch.compile(cls._quantized_op, mode="max-autotune-no-cudagraphs")
q_c_op = torch.compile(cls._quantized_linear_op, mode="max-autotune-no-cudagraphs")
else:
func = lambda a,b,c: F.relu(cls._quantized_op(F.relu(a), b, c))
func = lambda a,b,c: F.relu(cls._quantized_linear_op(F.relu(a), b, c))
q_c_op = torch.compile(func, mode="max-autotune-no-cudagraphs")
res = do_autoquant_bench(q_c_op, act_mat, w_qtensor, bias, warmup=25, rep=100)
if res < best_time*1.1:
@@ -263,10 +269,48 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
print(f">>time: {res:0.3f}ms for {cls}, to_beat: {best_time:0.3f}ms ")
return res

class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, Int8DynamicallyQuantizedLinearWeight):
class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, LinearActivationQuantizedTensor):
"""
AutoQuantizable version of Int8DynamicallyQuantizedLinearWeight
"""
@classmethod
def from_float(cls, weight):
# TODO test if this is valid
# in_features = weight.shape[1]
# int8 dynamic quantization only has benefit when in_feature > 16
# if in_features <= 16:
# return weight
Comment on lines +279 to +282

Contributor: what are these? should these be enabled or removed?

Contributor Author: added a todo


# avoid circular dep
from torchao.dtypes import to_affine_quantized

Contributor: can we reuse the code from quant_api.py? also we need to refactor the input_quant_func to be a normal function (not lambda) in order for serialization to work I think, might help to do that refactor at the same time

Contributor Author (@HDCharles, Aug 8, 2024): the issue is that we need to call from_float with super in order to get this to work correctly. The code in quant_api would generate an aqt, not an autoquant class inheriting from an aqt. (I tried this approach initially since that's how it worked with subclass)

If we want to reuse the code, it may make sense to make a function that prepares all the variables needed to go into from_float:

def int8_weight_only_kwargs(weight):
    # ...do the code up to to_affine_quantized and put it all together
    return a_bunch_of_kwargs

then you could have the quant_api code be like

def int8_weight_only():
    def apply_int8wo_quant(weight):
        kwargs = int8_weight_only_kwargs(weight)
        return to_affine_quantized(**kwargs)
    return _get_linear_subclass_inserter(apply_int8wo_quant)

then in autoquant we could do similar

def from_float(weight):
    kwargs = int8_weight_only_kwargs(weight)
    super().from_float(**kwargs)

Contributor: does moving the apply_int8wo_quant function to top level help? I'm doing it here: https://github.com/pytorch/ao/pull/630/files, seems like you call the super().from_float after the aqt is produced right

Contributor: we can move the apply-quant-to-weight-tensor function to top level as well I think
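A self-contained sketch of the kwargs-builder refactor proposed in this thread, using hypothetical names (int8_weight_only_kwargs, BaseQuantTensor, AutoQuantTensor) in place of torchao's real to_affine_quantized / AffineQuantizedTensor; the point it illustrates is that routing through super().from_float keeps the autoquant subclass type, while the functional path returns the base type.

import torch

def int8_weight_only_kwargs(weight):
    # Gather, in one place, everything needed to construct the quantized tensor.
    return {
        "input_float": weight,
        "block_size": (1, weight.shape[1]),
        "target_dtype": "int8",
    }

class BaseQuantTensor:
    # Stand-in for the tensor produced by to_affine_quantized.
    def __init__(self, **kwargs):
        self.kwargs = kwargs

    @classmethod
    def from_float(cls, **kwargs):
        return cls(**kwargs)

class AutoQuantTensor(BaseQuantTensor):
    @classmethod
    def from_float(cls, weight):
        # Reuse the shared builder, but go through super() so the result
        # keeps the autoquant subclass type.
        return super().from_float(**int8_weight_only_kwargs(weight))

def int8_weight_only(weight):
    # quant_api-style path: same builder, base-type result.
    return BaseQuantTensor.from_float(**int8_weight_only_kwargs(weight))

w = torch.randn(4, 8)
print(type(AutoQuantTensor.from_float(w)).__name__)  # AutoQuantTensor
print(type(int8_weight_only(w)).__name__)             # BaseQuantTensor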

# weight settings
mapping_type = MappingType.SYMMETRIC
def get_weight_block_size(x):
    return (1, x.shape[1])
target_dtype = torch.int8
eps = torch.finfo(torch.float32).eps
zero_point_dtype = torch.int64

# input settings
def get_per_token_block_size(x):
    block_size = list(x.shape)
    for i in range(len(block_size)-1):
        block_size[i] = 1
    return block_size

input_mapping_type = MappingType.SYMMETRIC
input_target_dtype = torch.int8
input_eps = 1e-5
input_quant_min = -127
input_quant_max = 127
layout_type = PlainLayoutType()
input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)

block_size = get_weight_block_size(weight)
weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, layout_type=layout_type)
weight = super(AQInt8DynamicallyQuantizedLinearWeight, cls).from_float(weight, input_quant_func)
return weight

@classmethod
def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
"""
@@ -298,7 +342,8 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
)
q_c_matmul=torch.compile(quantized_matmul, mode="max-autotune-no-cudagraphs")
with torch.no_grad():
res_matmul = do_autoquant_bench(q_c_matmul, x_vals_int8, x_scales.reshape(-1,1), w_qtensor.int_data)
w_vals_int8 = w_qtensor.original_weight_tensor.layout_tensor.int_data.contiguous().t()
res_matmul = do_autoquant_bench(q_c_matmul, x_vals_int8, x_scales.reshape(-1,1), w_vals_int8)
print(f">>time: {res_matmul:0.3f}ms for {cls} matmul, to_beat: {best_time:0.3f}ms")

# if the (much faster) matmul kernel is already beat, don't bother benchmarking full op
@@ -313,18 +358,27 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
print(f">>time: {res_f:0.3f}ms for {cls} interpolated, breakeven constant: {max_int_const_win:0.2f}")
return res_f

class AQWeightOnlyQuantizedLinearWeight(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
class AQWeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin):
"""
AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight
"""
@classmethod
def from_float(cls, weight):
mapping_type = MappingType.SYMMETRIC
target_dtype = torch.int8
eps = torch.finfo(torch.float32).eps
zero_point_dtype = torch.int64
block_size = (1, weight.shape[1])
return super(AQWeightOnlyQuantizedLinearWeight, cls).from_float(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)


class AQWeightOnlyQuantizedLinearWeight2(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
class AQWeightOnlyQuantizedLinearWeight2(AQWeightOnlyQuantizedLinearWeight, AQMixin):
"""
AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight that
uses a different kernel
"""
@staticmethod
def _quantized_op(act_mat, w_qtensor, bias):
def _quantized_linear_op(act_mat, w_qtensor, bias):
"""
Performs the quantized linear operations

Expand All @@ -339,8 +393,8 @@ def _quantized_op(act_mat, w_qtensor, bias):
orig_dtype = act_mat.dtype
orig_shape = act_mat.shape
act_mat = act_mat.reshape(-1, act_mat.shape[-1], 1)
y = (act_mat*w_qtensor.int_data.unsqueeze(0)).sum(dim=-2)
y = y.reshape(*orig_shape[:-1], y.shape[-1]) * w_qtensor.q_scales
y = (act_mat*w_qtensor.layout_tensor.int_data.t().unsqueeze(0)).sum(dim=-2)
y = y.reshape(*orig_shape[:-1], y.shape[-1]) * w_qtensor.layout_tensor.scale
if bias is not None:
y += bias
return y.to(orig_dtype)
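
A quick numerical check (shapes, dtypes, and scale layout are assumptions for illustration, not part of the PR) that the broadcast-multiply-and-sum used by this kernel matches an ordinary linear against the dequantized weight:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
m, k, n = 4, 8, 16
act = torch.randn(m, k)
w_int8 = torch.randint(-128, 127, (n, k), dtype=torch.int8)
scale = torch.rand(n)  # assumed per-output-channel scale

# Reference: plain linear against the dequantized weight.
ref = F.linear(act, w_int8.to(act.dtype) * scale.unsqueeze(-1))

# Decomposed path, mirroring the _quantized_linear_op body above.
y = (act.reshape(-1, k, 1) * w_int8.t().unsqueeze(0)).sum(dim=-2) * scale

print(torch.allclose(ref, y, rtol=1e-4, atol=1e-4))  # True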
@@ -352,14 +406,14 @@ def _autoquant_test(cls, act_mat, *args):
return torch.inf
return super()._autoquant_test(act_mat, *args)

class AQWeightOnlyQuantizedLinearWeight3(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
class AQWeightOnlyQuantizedLinearWeight3(AQWeightOnlyQuantizedLinearWeight, AQMixin):
"""
AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight that
uses a different kernel
"""
def _quantized_op(act_mat, w_qtensor, bias):
def _quantized_linear_op(act_mat, w_qtensor, bias):
orig_shape = act_mat.shape
y = torch.mm(act_mat.reshape(-1, orig_shape[-1]), w_qtensor.int_data*w_qtensor.q_scales)
y = torch.mm(act_mat.reshape(-1, orig_shape[-1]), w_qtensor.layout_tensor.int_data.t()*w_qtensor.layout_tensor.scale)
y=y.reshape(*orig_shape[:-1], y.shape[-1])
if bias is not None:
y += bias
@@ -377,7 +431,7 @@ def __init__(self):
super().__init__()

@staticmethod
def _quantized_op(act_mat, w_qtensor, bias):
def _quantized_linear_op(act_mat, w_qtensor, bias):
return torch.nn.functional.linear(act_mat, w_qtensor, bias)

@classmethod
12 changes: 8 additions & 4 deletions torchao/quantization/linear_activation_quantized_tensor.py
@@ -56,6 +56,13 @@ def __tensor_unflatten__(
input_quant_func,
)

@staticmethod
def _quantized_linear_op(input_tensor, weight_tensor, bias):
input_quant_func = weight_tensor.input_quant_func
original_weight_tensor = weight_tensor.original_weight_tensor
aqt = input_quant_func(input_tensor)
return torch.nn.functional.linear(aqt, original_weight_tensor, bias)

@classmethod
def from_float(cls, input_float, input_quant_func):
return cls(input_float, input_quant_func)
@@ -101,10 +108,7 @@ def _(func, types, args, kwargs):
args[2] if len(args) > 2 else None,
)
if isinstance(weight_tensor, LinearActivationQuantizedTensor):
input_quant_func = weight_tensor.input_quant_func
original_weight_tensor = weight_tensor.original_weight_tensor
aqt = input_quant_func(input_tensor)
return torch.nn.functional.linear(aqt, original_weight_tensor, bias)
return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)

raise NotImplementedError("LinearActivationQuantizedTensor: No specialized dispatch found for linear op")

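A rough sketch of the flow that _quantized_linear_op encapsulates for LinearActivationQuantizedTensor, with an assumed toy input_quant_func standing in for to_affine_quantized and a plain dense weight standing in for the wrapped original_weight_tensor:

import torch
import torch.nn.functional as F

def fake_input_quant(x):
    # Toy dynamic activation quantization: symmetric per-tensor int8
    # quantize followed by immediate dequantize, purely for illustration.
    scale = x.abs().amax().clamp(min=1e-8) / 127.0
    return (x / scale).round().clamp(-127, 127) * scale

def quantized_linear_op(input_tensor, weight_tensor, bias, input_quant_func=fake_input_quant):
    aq_input = input_quant_func(input_tensor)       # quantize the activation on the fly
    return F.linear(aq_input, weight_tensor, bias)  # in torchao the weight here is the wrapped quantized tensor

x, w = torch.randn(2, 8), torch.randn(4, 8)
print(quantized_linear_op(x, w, None).shape)  # torch.Size([2, 4])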