autoquant using aqt #609
@@ -1,10 +1,16 @@
 import torch
 import torchao
+from torchao.quantization.quant_primitives import (
+    MappingType,
+    ZeroPointDomain,
+)
 from .subclass import (  # noqa
     Int8DynamicallyQuantizedLinearWeight,
     Int8WeightOnlyQuantizedLinearWeight,
     QuantizedLinearWeightBase,
 )
+from torchao.dtypes import AffineQuantizedTensor, PlainLayoutType
+from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
 from torch.utils._python_dispatch import return_and_correct_aliasing
 from .quant_primitives import (
     safe_int_mm,
@@ -252,9 +258,9 @@ class AQMixin():
     def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
         w_qtensor = cls.from_float(weight)
         if _is_interpolate_mode(mode):
-            q_c_op = torch.compile(cls._quantized_op, mode="max-autotune-no-cudagraphs")
+            q_c_op = torch.compile(cls._quantized_linear_op, mode="max-autotune-no-cudagraphs")
         else:
-            func = lambda a,b,c: F.relu(cls._quantized_op(F.relu(a), b, c))
+            func = lambda a,b,c: F.relu(cls._quantized_linear_op(F.relu(a), b, c))
             q_c_op = torch.compile(func, mode="max-autotune-no-cudagraphs")
         res = do_autoquant_bench(q_c_op, act_mat, w_qtensor, bias, warmup=25, rep=100)
         if res < best_time*1.1:
@@ -263,10 +269,48 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
         print(f">>time: {res:0.3f}ms for {cls}, to_beat: {best_time:0.3f}ms ")
         return res
 
-class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, Int8DynamicallyQuantizedLinearWeight):
+class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, LinearActivationQuantizedTensor):
     """
     AutoQuantizable version of Int8DynamicallyQuantizedLinearWeight
     """
+    @classmethod
+    def from_float(cls, weight):
+        # TODO test if this is valid
+        # in_features = weight.shape[1]
+        # int8 dynamic quantization only has benefit when in_feature > 16
+        # if in_features <= 16:
+        #     return weight
+
+        # avoid circular dep
+        from torchao.dtypes import to_affine_quantized
Review comment: can we reuse the code from quant_api.py? also we need to refactor the
Reply: the issue is that we need to call from_float with super in order to get this to work correctly. The code in quant_api would generate an aqt, not an autoquant class inheriting from an aqt. (I tried this approach initially since that's how it worked with subclass.) If we want to reuse the code, it may make sense to make a function that prepares all the variables needed to go into from_float.

Review comment: does move
Reply: we can move the apply quant to weight Tensor function to top level as well I think.
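A minimal sketch of the refactor suggested in the reply above: a top-level helper that prepares the settings shared by the quant_api path and the autoquant from_float, so each autoquant class can still call from_float via super() on itself. The helper name and exact set of kwargs are hypothetical, not part of this PR.

```python
# Hypothetical helper, not part of this PR: factor the int8 weight settings
# out of quant_api so quant_api and the autoquant from_float can share them,
# while each autoquant class still calls from_float via super() on itself.
import torch
from torchao.quantization.quant_primitives import MappingType

def int8_dynamic_quant_settings(weight: torch.Tensor) -> dict:
    """Weight-side kwargs that both code paths would feed to to_affine_quantized."""
    return dict(
        mapping_type=MappingType.SYMMETRIC,
        block_size=(1, weight.shape[1]),   # per-row (per-channel) granularity
        target_dtype=torch.int8,
        eps=torch.finfo(torch.float32).eps,
        zero_point_dtype=torch.int64,
    )
```

Each subclass would then unpack these settings into its own super().from_float call, while quant_api could pass them straight to to_affine_quantized.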
+
+        # weight settings
+        mapping_type = MappingType.SYMMETRIC
+        def get_weight_block_size(x):
+            return (1, x.shape[1])
+        target_dtype = torch.int8
+        eps = torch.finfo(torch.float32).eps
+        zero_point_dtype = torch.int64
+
+        # input settings
+        def get_per_token_block_size(x):
+            block_size = list(x.shape)
+            for i in range(len(block_size)-1):
+                block_size[i] = 1
+            return block_size
+
+        input_mapping_type = MappingType.SYMMETRIC
+        input_target_dtype = torch.int8
+        input_eps = 1e-5
+        input_quant_min = -127
+        input_quant_max = 127
+        layout_type = PlainLayoutType()
+        input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)
+
+        block_size = get_weight_block_size(weight)
+        weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, layout_type=layout_type)
+        weight = super(AQInt8DynamicallyQuantizedLinearWeight, cls).from_float(weight, input_quant_func)
+        return weight
+
     @classmethod
     def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
         """
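For reference, the block sizes configured above translate into the following granularities. The snippet is illustrative only (the shapes are made up) and simply mirrors get_weight_block_size and get_per_token_block_size from the diff.

```python
import torch

# Illustration only: made-up shapes showing the granularities that
# get_weight_block_size and get_per_token_block_size configure.
weight = torch.randn(4096, 1024)       # (out_features, in_features)
act = torch.randn(8, 128, 1024)        # (batch, seq_len, hidden)

# One scale per output row of the weight (per-channel):
weight_block_size = (1, weight.shape[1])                    # -> (1, 1024)

# One scale per token of the activation (per-token dynamic quantization):
act_block_size = [1] * (act.dim() - 1) + [act.shape[-1]]    # -> [1, 1, 1024]

print(weight_block_size, act_block_size)
```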
@@ -298,7 +342,8 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
         )
         q_c_matmul=torch.compile(quantized_matmul, mode="max-autotune-no-cudagraphs")
         with torch.no_grad():
-            res_matmul = do_autoquant_bench(q_c_matmul, x_vals_int8, x_scales.reshape(-1,1), w_qtensor.int_data)
+            w_vals_int8 = w_qtensor.original_weight_tensor.layout_tensor.int_data.contiguous().t()
+            res_matmul = do_autoquant_bench(q_c_matmul, x_vals_int8, x_scales.reshape(-1,1), w_vals_int8)
         print(f">>time: {res_matmul:0.3f}ms for {cls} matmul, to_beat: {best_time:0.3f}ms")
 
         # if the (much faster) matmul kernel is already beat, don't bother benchmarking full op
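The extra .contiguous().t() suggests the benchmarked matmul expects an (in_features, out_features) int8 weight operand, while the plain AQT layout appears to keep int_data in the original (out_features, in_features) weight layout (the old subclass presumably stored it pre-transposed). A purely illustrative shape check, with made-up dimensions:

```python
import torch

# Assumption: the plain AQT layout stores int_data in the same
# (out_features, in_features) layout as the float weight, whereas the
# benchmarked matmul wants (in_features, out_features), hence the .t() above.
out_features, in_features = 4096, 1024
int_data = torch.randint(-128, 127, (out_features, in_features), dtype=torch.int8)
w_vals_int8 = int_data.contiguous().t()
assert w_vals_int8.shape == (in_features, out_features)
```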
@@ -313,18 +358,27 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
         print(f">>time: {res_f:0.3f}ms for {cls} interpolated, breakeven constant: {max_int_const_win:0.2f}")
         return res_f
 
-class AQWeightOnlyQuantizedLinearWeight(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
+class AQWeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin):
     """
     AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight
     """
+    @classmethod
+    def from_float(cls, weight):
+        mapping_type = MappingType.SYMMETRIC
+        target_dtype = torch.int8
+        eps = torch.finfo(torch.float32).eps
+        zero_point_dtype = torch.int64
+        block_size = (1, weight.shape[1])
+        return super(AQWeightOnlyQuantizedLinearWeight, cls).from_float(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
+
 
-class AQWeightOnlyQuantizedLinearWeight2(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
+class AQWeightOnlyQuantizedLinearWeight2(AQWeightOnlyQuantizedLinearWeight, AQMixin):
     """
     AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight that
     uses a different kernel
     """
     @staticmethod
-    def _quantized_op(act_mat, w_qtensor, bias):
+    def _quantized_linear_op(act_mat, w_qtensor, bias):
         """
         Performs the quantized linear operations
@@ -339,8 +393,8 @@ def _quantized_op(act_mat, w_qtensor, bias):
         orig_dtype = act_mat.dtype
         orig_shape = act_mat.shape
         act_mat = act_mat.reshape(-1, act_mat.shape[-1], 1)
-        y = (act_mat*w_qtensor.int_data.unsqueeze(0)).sum(dim=-2)
-        y = y.reshape(*orig_shape[:-1], y.shape[-1]) * w_qtensor.q_scales
+        y = (act_mat*w_qtensor.layout_tensor.int_data.t().unsqueeze(0)).sum(dim=-2)
+        y = y.reshape(*orig_shape[:-1], y.shape[-1]) * w_qtensor.layout_tensor.scale
         if bias is not None:
             y += bias
         return y.to(orig_dtype)
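As a sanity-check sketch (not part of the PR), the broadcast-multiply-and-sum kernel above is numerically equivalent to a plain linear on the dequantized weight; AQWeightOnlyQuantizedLinearWeight3 below computes effectively the same thing via torch.mm. Shapes and tolerances here are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

# Sanity check: the broadcast-multiply-and-sum kernel matches a plain linear
# on the dequantized weight. Shapes are made up for illustration.
tokens, in_features, out_features = 16, 64, 32
act = torch.randn(tokens, in_features)
int_data = torch.randint(-128, 127, (out_features, in_features), dtype=torch.int8)
scale = torch.rand(out_features)

# kernel used by AQWeightOnlyQuantizedLinearWeight2._quantized_linear_op
y_kernel = (act.reshape(-1, in_features, 1) * int_data.t().unsqueeze(0)).sum(dim=-2) * scale

# reference: dequantize, then a regular matmul (roughly what variant 3 does via torch.mm)
y_ref = F.linear(act, int_data.to(act.dtype) * scale.unsqueeze(-1))

torch.testing.assert_close(y_kernel, y_ref, rtol=1e-3, atol=1e-3)
```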
@@ -352,14 +406,14 @@ def _autoquant_test(cls, act_mat, *args):
             return torch.inf
         return super()._autoquant_test(act_mat, *args)
 
-class AQWeightOnlyQuantizedLinearWeight3(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
+class AQWeightOnlyQuantizedLinearWeight3(AQWeightOnlyQuantizedLinearWeight, AQMixin):
     """
     AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight that
     uses a different kernel
     """
-    def _quantized_op(act_mat, w_qtensor, bias):
+    def _quantized_linear_op(act_mat, w_qtensor, bias):
         orig_shape = act_mat.shape
-        y = torch.mm(act_mat.reshape(-1, orig_shape[-1]), w_qtensor.int_data*w_qtensor.q_scales)
+        y = torch.mm(act_mat.reshape(-1, orig_shape[-1]), w_qtensor.layout_tensor.int_data.t()*w_qtensor.layout_tensor.scale)
         y=y.reshape(*orig_shape[:-1], y.shape[-1])
         if bias is not None:
             y += bias
@@ -377,7 +431,7 @@ def __init__(self):
         super().__init__()
 
     @staticmethod
-    def _quantized_op(act_mat, w_qtensor, bias):
+    def _quantized_linear_op(act_mat, w_qtensor, bias):
         return torch.nn.functional.linear(act_mat, w_qtensor, bias)
 
     @classmethod
Review comment: what are these? should these be enabled or removed?
Reply: added a todo.