
Add decorator for custom op and inductor decomp registration #408

Closed · wants to merge 1 commit
1 change: 0 additions & 1 deletion test/integration/test_integration.py
@@ -37,7 +37,6 @@
choose_qparams_affine,
quantize_affine,
dequantize_affine,
- MappingType,
)
from torchao.quantization.utils import (
dequantize_per_channel,
4 changes: 0 additions & 4 deletions test/quantization/test_quant_api.py
@@ -22,10 +22,6 @@
from torchao.dtypes import (
AffineQuantizedTensor,
)
- from torchao.quantization.quant_primitives import (
- MappingType,
- ZeroPointDomain,
- )
from torchao.quantization.subclass import (
LinearActQuantizedTensor,
Int8WeightOnlyQuantizedLinearWeight,
32 changes: 15 additions & 17 deletions test/quantization/test_quant_primitives.py
@@ -12,8 +12,6 @@
quantize_affine,
dequantize_affine,
choose_qparams_affine,
- MappingType,
- ZeroPointDomain,
)
# TODO: remove test for utils?
from torchao.quantization.utils import (
@@ -167,7 +165,7 @@ def test_choose_qparams_group_sym(self):
we don't include it here. We may just replace it with per block quant
"""
input = torch.randn(10, 10)
- mapping_type = MappingType.SYMMETRIC
+ mapping_type = "symmetric"
dtype = torch.int8
block_size = (1, 2)
eps = torch.finfo(torch.float32).eps
@@ -183,7 +181,7 @@ def test_choose_qparams_group_sym(self):
@unittest.skipIf(is_fbcode(), "broken in fbcode")
def test_choose_qparams_token_asym(self):
input = torch.randn(10, 10)
- mapping_type = MappingType.ASYMMETRIC
+ mapping_type = "asymmetric"
dtype = torch.int8
block_size = (1, 10)
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
@@ -198,7 +196,7 @@ def test_choose_qparams_token_asym(self):
@unittest.skipIf(is_fbcode(), "broken in fbcode")
def test_choose_qparams_tensor_asym(self):
input = torch.randn(10, 10)
- mapping_type = MappingType.ASYMMETRIC
+ mapping_type = "asymmetric"
dtype = torch.int8
block_size = (10, 10)
eps = torch.finfo(torch.float32).eps
@@ -217,7 +215,7 @@ def test_choose_qparams_tensor_asym(self):
@unittest.skipIf(is_fbcode(), "broken in fbcode")
def test_choose_qparams_tensor_sym(self):
input = torch.randn(10, 10)
- mapping_type = MappingType.SYMMETRIC
+ mapping_type = "symmetric"
dtype = torch.int8
block_size = (10, 10)
eps = torch.finfo(torch.float32).eps
@@ -237,7 +235,7 @@ def test_quantize_activation_per_token_abs_max(self):
input = torch.randn(10, 10)
quantized_ref, scale_ref = quantize_activation_per_token_absmax(input)

- mapping_type = MappingType.SYMMETRIC
+ mapping_type = "symmetric"
block_size = list(input.shape)
for i in range(len(block_size) - 1):
block_size[i] = 1
@@ -278,7 +276,7 @@ def test_quantize_activation_per_token_abs_max_dtype(self):
@unittest.skipIf(is_fbcode(), "broken in fbcode")
def test_quantize_dequantize_group_sym(self):
input = torch.randn(10, 10)
- mapping_type = MappingType.SYMMETRIC
+ mapping_type = "symmetric"
dtype = torch.int8
block_size = (1, 2)
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
@@ -303,7 +301,7 @@ def test_quantize_dequantize_group_sym(self):
@unittest.skipIf(is_fbcode(), "broken in fbcode")
def test_quantize_dequantize_channel_asym(self):
input = torch.randn(10, 10)
- mapping_type = MappingType.ASYMMETRIC
+ mapping_type = "asymmetric"
dtype = torch.int8
block_size = (10, 1)
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
@@ -327,7 +325,7 @@ def test_quantize_dequantize_channel_asym(self):
@unittest.skipIf(is_fbcode(), "broken in fbcode")
def test_quantize_dequantize_tensor_asym(self):
input = torch.randn(10, 10)
- mapping_type = MappingType.ASYMMETRIC
+ mapping_type = "asymmetric"
dtype = torch.int8
block_size = (10, 10)
output_dtype = torch.float32
@@ -351,7 +349,7 @@ def test_quantize_dequantize_tensor_asym(self):
@unittest.skipIf(is_fbcode(), "broken in fbcode")
def test_quantize_dequantize_channel_asym_4d(self):
input = torch.randn(3, 3, 10, 10)
- mapping_type = MappingType.ASYMMETRIC
+ mapping_type = "asymmetric"
dtype = torch.int8
block_size = (3, 3, 1, 10)
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
@@ -373,7 +371,7 @@ def test_quantize_dequantize_channel_asym_4d(self):
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch version is 2.3 or lower")
def test_quantize_dequantize_channel_asym_4d_multi_dim_reduction(self):
input = torch.randn(3, 3, 10, 10)
- mapping_type = MappingType.ASYMMETRIC
+ mapping_type = "asymmetric"
dtype = torch.int8
block_size = (3, 3, 2, 2)
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
@@ -384,7 +382,7 @@ def test_quantize_dequantize_channel_asym_4d_multi_dim_reduction(self):

def test_choose_qparams_tensor_asym_eps(self):
input = torch.zeros(10, 10)
- mapping_type = MappingType.ASYMMETRIC
+ mapping_type = "asymmetric"
dtype = torch.int8
block_size = (10, 10)
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype)
@@ -406,7 +404,7 @@ def test_raises(self):
"""Make sure some errors are raised when user requested an unsupported type of quantization
"""
input = torch.randn(10, 10)
- mapping_type = MappingType.ASYMMETRIC
+ mapping_type = "asymmetric"
dtype = torch.int8
block_size = (10, 10)
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype)
@@ -425,7 +423,7 @@ def test_not_preserve_zero_not_supported(self):
"""Making sure preserve_zero == False is not supported for symmetric quant"""
input = torch.randn(10, 256)
n_bit = 4
- mapping_type = MappingType.SYMMETRIC
+ mapping_type = "symmetric"
dtype = torch.int8
block_size = (1, 128)
quant_min = 0
@@ -453,7 +451,7 @@ def test_get_groupwise_affine_qparams(self):
n_bit = 4
scale_ref, zero_point_ref = _get_groupwise_affine_qparams(input, n_bit=n_bit, groupsize=128, dtype=torch.bfloat16)

- mapping_type = MappingType.ASYMMETRIC
+ mapping_type = "asymmetric"
dtype = torch.int8
block_size = (1, 128)
quant_min = 0
@@ -473,7 +471,7 @@ def test_get_groupwise_affine_qparams(self):
scale_dtype=scale_dtype,
zero_point_dtype=zero_point_dtype,
preserve_zero=False,
- zero_point_domain=ZeroPointDomain.FLOAT,
+ zero_point_domain="float",
)

self.assertTrue(torch.equal(scale, scale_ref))
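For context, a minimal sketch of how the affected primitives are called after this change, with `mapping_type` passed as a plain string ("symmetric"/"asymmetric") instead of a `MappingType` enum member. Shapes and settings are illustrative and mirror the tests above; the `quantize_affine`/`dequantize_affine` call pattern is an assumption based on torchao's `quant_primitives` signatures at the time of this PR.

```python
import torch
from torchao.quantization.quant_primitives import (
    choose_qparams_affine,
    quantize_affine,
    dequantize_affine,
)

# Groupwise symmetric int8 quantization with groups of 2 along the last dim.
input = torch.randn(10, 10)
mapping_type = "symmetric"      # was MappingType.SYMMETRIC before this PR
dtype = torch.int8
block_size = (1, 2)
eps = torch.finfo(torch.float32).eps

scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=eps)
quantized = quantize_affine(input, block_size, scale, zero_point, dtype)
dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32)
```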
20 changes: 9 additions & 11 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -6,8 +6,6 @@
choose_qparams_affine,
quantize_affine,
dequantize_affine,
- ZeroPointDomain,
- MappingType,
int_scaled_matmul,
)
from torchao.quantization.utils import (
@@ -98,12 +96,12 @@ class AffineQuantizedTensor(torch.Tensor):
shape (torch.Size): the shape for the Tensor
quant_min (Optional[int]): minimum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data`
quant_max (Optional[int]): maximum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data`
- zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be eitehr integer or float
+ zero_point_domain (str): the domain that zero_point is in, should be either "int" or "float"
if zero_point is in integer domain, zero point is added to the quantized integer value during
quantization
if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
value during quantization
- default is ZeroPointDomain.INT
+ default is "int"
input_quant_func (Optional[Callable]): function for quantizing the input float Tensor to a quantized tensor subclass object, that takes float Tensor as input and outputs an AffineQuantizedTensor object
dtype: dtype for external representation of the tensor, e.g. torch.float32
"""
@@ -116,7 +114,7 @@ def __new__(
shape: torch.Size,
quant_min: Optional[int] = None,
quant_max: Optional[int] = None,
- zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
+ zero_point_domain: str = "int",
dtype=None,
strides=None,
):
@@ -138,7 +136,7 @@ def __init__(
shape: torch.Size,
quant_min: Optional[int] = None,
quant_max: Optional[int] = None,
- zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
+ zero_point_domain: str = "int",
dtype=None,
strides=None,
):
@@ -184,7 +182,7 @@ def __tensor_unflatten__(
def from_float(
cls,
input_float: torch.Tensor,
- mapping_type: MappingType,
+ mapping_type: str,
block_size: Tuple[int, ...],
target_dtype: torch.dtype,
quant_min: Optional[int] = None,
Expand All @@ -193,7 +191,7 @@ def from_float(
scale_dtype: Optional[torch.dtype] = None,
zero_point_dtype: Optional[torch.dtype] = None,
preserve_zero: bool = True,
- zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
+ zero_point_domain: str = "int",
extended_layout: str = "plain",
# TODO: this is only for "tensor_core_tiled", need to figure out
# the proper API for this arg
@@ -520,7 +518,7 @@ def get_plain(self):
target_dtype = torch.int32
quant_min = 0
quant_max = 15
- zero_point_domain = ZeroPointDomain.FLOAT
+ zero_point_domain = "float"
assert len(block_size) == 2 and block_size[0] == 1
groupsize = block_size[-1]
dequantized = torch.ops.aten._weight_int4pack_mm(torch.eye(eye_shape, device=device, dtype=original_dtype), self.packed_weight, groupsize, self.scale_and_zero)
@@ -597,7 +595,7 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
weight_is_uint4 and
weight_qtensor.dtype == torch.bfloat16 and
len(weight_qtensor.shape) == 2 and
- weight_qtensor.zero_point_domain == ZeroPointDomain.FLOAT and
+ weight_qtensor.zero_point_domain == "float" and
weight_qtensor.extended_layout == "tensor_core_tiled"
):
assert weight_qtensor.block_size[0] == 1, f"Requires groupwise quantization, got block_size: {block_size}"
@@ -640,7 +638,7 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
len(weight_qtensor.block_size) == 2 and
weight_qtensor.block_size[0] == 1 and
weight_qtensor.block_size[1] == weight_qtensor.shape[1] and
- weight_qtensor.zero_point_domain == ZeroPointDomain.INT and
+ weight_qtensor.zero_point_domain == "int" and
weight_qtensor.extended_layout == "plain"
):
# TODO: enable cpu and mps efficient path
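For reference, a minimal sketch of constructing an `AffineQuantizedTensor` through `from_float` with the string-valued `mapping_type` and `zero_point_domain` arguments used above. The settings mirror the int4 weight-only path in `quant_api.py`; the exact keyword set is an assumption based on the `from_float` signature shown in this diff.

```python
import torch
from torchao.dtypes import AffineQuantizedTensor

weight = torch.randn(128, 256, dtype=torch.bfloat16)

# "asymmetric"/"symmetric" replace MappingType members, and "int"/"float"
# replace ZeroPointDomain members, so both arguments are plain strings now.
aq_weight = AffineQuantizedTensor.from_float(
    weight,
    mapping_type="asymmetric",
    block_size=(1, 128),          # groupwise quantization along the last dim
    target_dtype=torch.int32,
    quant_min=0,
    quant_max=15,
    eps=1e-6,
    zero_point_dtype=torch.bfloat16,
    preserve_zero=False,
    zero_point_domain="float",
)
```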
23 changes: 9 additions & 14 deletions torchao/quantization/quant_api.py
@@ -31,10 +31,6 @@
to_linear_act_quantized,
)

- from .quant_primitives import (
- MappingType,
- ZeroPointDomain,
- )
from .weight_only import WeightOnlyInt8QuantLinear
from .unified import Quantizer, TwoStepQuantizer
from .GPTQ import (
@@ -270,15 +266,15 @@ def quantize(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tens

# weight settings
groupsize = 32
- mapping_type = MappingType.ASYMMETRIC
+ mapping_type = "asymmetric"
block_size = (1, groupsize)
target_dtype = torch.int32
quant_min = 0
quant_max = 15
eps = 1e-6
preserve_zero = False
zero_point_dtype = torch.bfloat16
- zero_point_domain = ZeroPointDomain.FLOAT
+ zero_point_domain = "float"

apply_weight_quant = lambda x: to_affine_quantized(
x, mapping_type, block_size, target_dtype, quant_min, quant_max, eps,
@@ -319,7 +315,7 @@ def apply_int8_dynamic_activation_int4_weight_quant(weight):
from torchao.dtypes import to_affine_quantized

# weight settings
- mapping_type = MappingType.SYMMETRIC
+ mapping_type = "symmetric"
block_size = (1, group_size)
target_dtype = torch.int8
eps = torch.finfo(torch.float32).eps
@@ -336,7 +332,7 @@ def get_per_token_block_size(x):
return block_size

# input settings
- input_mapping_type = MappingType.ASYMMETRIC
+ input_mapping_type = "asymmetric"
input_target_dtype = torch.int8
input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype)

@@ -360,16 +356,15 @@ def int4_weight_only(group_size=128, inner_k_tiles=8):
def apply_int4_weight_only_quant(weight):
# avoid circular dep
from torchao.dtypes import to_affine_quantized

- mapping_type = MappingType.ASYMMETRIC
+ mapping_type = "asymmetric"
block_size = (1, group_size)
target_dtype = torch.int32
quant_min = 0
quant_max = 15
eps = 1e-6
preserve_zero = False
zero_point_dtype = torch.bfloat16
- zero_point_domain = ZeroPointDomain.FLOAT
+ zero_point_domain = "float"
return to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, extended_layout="tensor_core_tiled", inner_k_tiles=inner_k_tiles)

return apply_int4_weight_only_quant
@@ -383,7 +378,7 @@ def apply_int8wo_quant(weight):
# avoid circular dep
from torchao.dtypes import to_affine_quantized

- mapping_type = MappingType.SYMMETRIC
+ mapping_type = "symmetric"
target_dtype = torch.int8
eps = torch.finfo(torch.float32).eps
zero_point_dtype = torch.int64
@@ -406,7 +401,7 @@ def apply_int8_dynamic_activation_int8_weight_quant(weight):
# avoid circular dep
from torchao.dtypes import to_affine_quantized
# weight settings
- mapping_type = MappingType.SYMMETRIC
+ mapping_type = "symmetric"
def get_weight_block_size(x):
return (1, x.shape[1])
target_dtype = torch.int8
@@ -420,7 +415,7 @@ def get_per_token_block_size(x):
block_size[i] = 1
return block_size

- input_mapping_type = MappingType.SYMMETRIC
+ input_mapping_type = "symmetric"
input_target_dtype = torch.int8
input_eps = 1e-5
input_quant_min = -127
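As a usage sketch (not part of this diff), the factories above plug into the `quantize` API defined in this file. The call below assumes `quantize` returns the modified module and that the int4 tensor-core-tiled path needs a bfloat16 model on CUDA; the toy model and sizes are illustrative.

```python
import torch
from torchao.quantization.quant_api import quantize, int4_weight_only

# Toy module; any nn.Module containing nn.Linear layers is handled the same way.
model = torch.nn.Sequential(torch.nn.Linear(256, 512)).to(torch.bfloat16).cuda()

# int4_weight_only() builds the apply_tensor_subclass callable that quantize()
# applies to the linear weights, using the string-based mapping_type and
# zero_point_domain settings shown above internally.
model = quantize(model, int4_weight_only(group_size=128, inner_k_tiles=8))
```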