Add developer guide code to tutorials #588

Merged
1 commit merged on Aug 2, 2024
147 changes: 79 additions & 68 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -22,71 +22,17 @@
_register_layout_cls,
_get_layout_tensor_constructor,
LayoutType,
PlainLayoutType,
is_device,
)
from typing import ClassVar
from dataclasses import dataclass
from torchao.utils import TORCH_VERSION_AFTER_2_5

aten = torch.ops.aten

@dataclass(frozen=True)
class PlainLayoutType(LayoutType):
pass

@dataclass(frozen=True)
class SemiSparseLayoutType(LayoutType):

def pre_process(self, input: torch.Tensor) -> torch.Tensor:
# prune to 2:4 if not already
temp = input.detach()
pruning_inds = temp.abs().view(-1, 4).argsort(dim=1)[:, :2]
temp.view(-1, 4).scatter_(1, pruning_inds, value=0)
return temp


@dataclass(frozen=True)
class TensorCoreTiledLayoutType(LayoutType):
inner_k_tiles: int = 8

def pre_process(self, input: torch.Tensor) -> torch.Tensor:
orig_out_features, orig_in_features = input.shape
in_features = find_multiple(orig_in_features, 1024)
out_features = find_multiple(orig_out_features, 8)
input = torch.nn.functional.pad(
input,
(0, in_features - orig_in_features, 0, out_features - orig_out_features),
)
return input

def extra_repr(self):
return f"inner_k_tiles={self.inner_k_tiles}"


def _aqt_is_int8(aqt):
"""Check if an AffineQuantizedTensor is int8 quantized Tensor"""
return (
aqt.layout_tensor.dtype == torch.int8 and
aqt.quant_min is None or aqt.quant_min == -128 and
aqt.quant_max is None or aqt.quant_max == 127
)

def _aqt_is_int8_reduced_range(aqt):
return (
aqt.layout_tensor.dtype == torch.int8 and
aqt.quant_min == -127 and
aqt.quant_max is None or aqt.quant_max == 127
)

def _aqt_is_uint4(aqt):
"""Check if an AffineQuantizedTensor is uint4 quantized Tensor"""
# TODO: use torch.uint4
return (
aqt.layout_tensor.dtype == torch.int32 and
aqt.quant_min is None or aqt.quant_min == 0 and
aqt.quant_max is None or aqt.quant_max == 15
)

###############################
# Base Layout Tensor Subclass #
###############################
class AQTLayout(torch.Tensor):
"""
Base class for the layout tensor for `AffineQuantizedTensor`
@@ -126,6 +72,10 @@ def _get_to_kwargs(self, *args, **kwargs):
}
return kwargs

##############################
# Tensor Subclass Definition #
##############################

class AffineQuantizedTensor(torch.Tensor):
"""
Affine quantized tensor subclass. Affine quantization means we quantize the floating point tensor with an affine transformation:
@@ -337,7 +287,6 @@ def _apply_fn_to_data(self, fn):
strides=self.stride(),
)


implements = classmethod(_implements)
# Note: we only added cpu path here for 8da4w, this is for executorch, in the future
# 1. we'll add cpu/cuda version (int4mm etc.)
@@ -353,14 +302,46 @@ def _apply_fn_to_data(self, fn):
__torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)
__torch_function__ = classmethod(_dispatch__torch_function__)

implements = AffineQuantizedTensor.implements

######################################################
# LayoutType and Layout Tensor Subclass Registration #
######################################################

def register_layout_cls(layout_type_class: type(LayoutType)):
return _register_layout_cls(AffineQuantizedTensor, layout_type_class)

def get_layout_tensor_constructor(layout_type_class: type(LayoutType)):
return _get_layout_tensor_constructor(AffineQuantizedTensor, layout_type_class)

@dataclass(frozen=True)
class SemiSparseLayoutType(LayoutType):

def pre_process(self, input: torch.Tensor) -> torch.Tensor:
# prune to 2:4 if not already
temp = input.detach()
pruning_inds = temp.abs().view(-1, 4).argsort(dim=1)[:, :2]
temp.view(-1, 4).scatter_(1, pruning_inds, value=0)
return temp


@dataclass(frozen=True)
class TensorCoreTiledLayoutType(LayoutType):
inner_k_tiles: int = 8

def pre_process(self, input: torch.Tensor) -> torch.Tensor:
orig_out_features, orig_in_features = input.shape
in_features = find_multiple(orig_in_features, 1024)
out_features = find_multiple(orig_out_features, 8)
input = torch.nn.functional.pad(
input,
(0, in_features - orig_in_features, 0, out_features - orig_out_features),
)
return input

def extra_repr(self):
return f"inner_k_tiles={self.inner_k_tiles}"


@register_layout_cls(PlainLayoutType)
class PlainAQTLayout(AQTLayout):
"""
@@ -487,7 +468,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
)

def get_plain(self):
# Currently we don't have cuSPARSELt expansion routines, so we matmul by
# Currently we don't have cuSPARSELt expansion routines, so we matmul by
# the identity matrix to get the original dense matrix. This is slow though.
cols = self.int_data.numel() * 16 // (10 * self.scale.shape[0])
int_data_expanded = torch._cslt_sparse_mm(self.int_data,
@@ -507,7 +488,7 @@ def from_plain(
assert isinstance(layout_type, SemiSparseLayoutType)
int_data_compressed = torch._cslt_compress(int_data)
return cls(int_data_compressed, scale, zero_point, layout_type)


@register_layout_cls(TensorCoreTiledLayoutType)
class TensorCoreTiledAQTLayout(AQTLayout):
@@ -654,6 +635,34 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
def get_layout_type(self) -> LayoutType:
return self.layout_type

#####################################################
# torch functional and aten operator implementation #
#####################################################

def _aqt_is_int8(aqt):
"""Check if an AffineQuantizedTensor is int8 quantized Tensor"""
return (
aqt.layout_tensor.dtype == torch.int8 and
aqt.quant_min is None or aqt.quant_min == -128 and
aqt.quant_max is None or aqt.quant_max == 127
)

def _aqt_is_int8_reduced_range(aqt):
return (
aqt.layout_tensor.dtype == torch.int8 and
aqt.quant_min == -127 and
aqt.quant_max is None or aqt.quant_max == 127
)

def _aqt_is_uint4(aqt):
"""Check if an AffineQuantizedTensor is uint4 quantized Tensor"""
# TODO: use torch.uint4
return (
aqt.layout_tensor.dtype == torch.int32 and
aqt.quant_min is None or aqt.quant_min == 0 and
aqt.quant_max is None or aqt.quant_max == 15
)

def _quantized_linear_op(input_tensor, weight_qtensor, bias):
"""
Quantized version of F.linear operator
@@ -811,8 +820,10 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
raise NotImplementedError("No specialized dispatch found for quantized linear op")


implements = AffineQuantizedTensor.implements

@implements(torch.nn.functional.linear)
def _(func, types, *args, **kwargs):
def _(func, types, args, kwargs):
input_tensor, weight_tensor, bias = (
args[0],
args[1],
@@ -831,7 +842,7 @@ def _(func, types, *args, **kwargs):
return torch.nn.functional.linear(input_tensor, weight_tensor, bias)

@implements([aten.mm.default, aten.addmm.default])
def _(func, types, *args, **kwargs):
def _(func, types, args, kwargs):
if not args[0].is_floating_point():
raise NotImplementedError(f"{func} is not implemented for non floating point input")

@@ -870,21 +881,21 @@ def _(func, types, *args, **kwargs):
return func(input_tensor, weight_tensor)

@implements([aten.detach.default])
def _(func, types, *args, **kwargs):
def _(func, types, args, kwargs):
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
)


@implements([aten.clone.default])
def _(func, types, *args, **kwargs):
def _(func, types, args, kwargs):
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
)


@implements([aten._to_copy.default])
def _(func, types, *args, **kwargs):
def _(func, types, args, kwargs):
return return_and_correct_aliasing(
func,
args,
Expand All @@ -893,7 +904,7 @@ def _(func, types, *args, **kwargs):
)

@implements([aten.t.default])
def _(func, types, *args, **kwargs):
def _(func, types, args, kwargs):
block_size = args[0].block_size
assert len(block_size) == 2
transposed_block_size = (block_size[1], block_size[0])
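
For orientation (not part of this PR's diff): with affine_quantized_tensor.py reorganized into the sections above, a new layout plugs in through register_layout_cls. The sketch below is a hedged illustration only; MyLayoutType and MyAQTLayout are hypothetical names, and the import paths are assumptions based on the modules visible in this diff.

# Hedged sketch: registering a custom layout for AffineQuantizedTensor,
# assuming the APIs shown in this diff (LayoutType, AQTLayout,
# register_layout_cls). MyLayoutType and MyAQTLayout are hypothetical.
from dataclasses import dataclass

from torchao.dtypes.utils import LayoutType
from torchao.dtypes.affine_quantized_tensor import (
    AQTLayout,
    register_layout_cls,
)


@dataclass(frozen=True)
class MyLayoutType(LayoutType):
    # layout-specific metadata rides on the frozen dataclass,
    # like inner_k_tiles on TensorCoreTiledLayoutType above
    block_size: int = 32

    def extra_repr(self) -> str:
        return f"block_size={self.block_size}"


@register_layout_cls(MyLayoutType)
class MyAQTLayout(AQTLayout):
    # a real layout tensor would implement __new__/__init__, from_plain,
    # get_plain, get_layout_type, and the flatten/unflatten protocol,
    # following PlainAQTLayout in this file
    pass
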
16 changes: 12 additions & 4 deletions torchao/dtypes/utils.py
@@ -32,8 +32,8 @@ def _(func, types, args, kwargs):
def decorator(func):
for op in aten_ops_or_torch_fns:
@functools.wraps(op)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)
def wrapper(f, types, args, kwargs):
return func(f, types, args, kwargs)

cls._ATEN_OP_OR_TORCH_FN_TABLE[op] = wrapper
return func
@@ -50,7 +50,7 @@ class MyTensor(torch.Tensor):
kwargs = {} if kwargs is None else kwargs
if hasattr(cls, "_ATEN_OP_OR_TORCH_FN_TABLE") and \
func in cls._ATEN_OP_OR_TORCH_FN_TABLE:
return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, *args, **kwargs)
return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, args, kwargs)

with torch._C.DisableTorchFunctionSubclass():
return func(*args, **kwargs)
@@ -65,7 +65,7 @@ class MyTensor(torch.Tensor):
"""
if hasattr(cls, "_ATEN_OP_OR_TORCH_FN_TABLE") and \
func in cls._ATEN_OP_OR_TORCH_FN_TABLE:
return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, *args, **kwargs)
return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, args, kwargs)

raise NotImplementedError(f"{cls.__name__} dispatch: attempting to run unimplemented operator/function: {func}")

@@ -87,6 +87,14 @@ def __repr__(self):
def extra_repr(self) -> str:
return ""

"""
Plain LayoutType, the most basic LayoutType, also has no extra metadata, will typically be the default
"""
@dataclass(frozen=True)
class PlainLayoutType(LayoutType):
pass


"""
layout tensor constructor registration for different tensor subclasses

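
The change to torchao/dtypes/utils.py above is the core of this PR's signature cleanup: handlers registered through the implements decorator now receive the packed (func, types, args, kwargs) tuple instead of unpacked *args/**kwargs. A minimal, self-contained toy of that convention follows; MyToyTensor and its helpers are not torchao code, just an assumption-labeled sketch using plain PyTorch.

# Self-contained sketch of the dispatch convention this diff adopts:
# registered handlers take (func, types, args, kwargs) packed.
# MyToyTensor is a toy wrapper subclass, not part of torchao.
import torch

aten = torch.ops.aten


class MyToyTensor(torch.Tensor):
    _ATEN_OP_TABLE = {}

    @staticmethod
    def __new__(cls, data):
        return torch.Tensor._make_wrapper_subclass(cls, data.shape, dtype=data.dtype)

    def __init__(self, data):
        self.inner = data

    @classmethod
    def implements(cls, aten_ops):
        def decorator(fn):
            for op in aten_ops:
                cls._ATEN_OP_TABLE[op] = fn
            return fn
        return decorator

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        kwargs = kwargs or {}
        if func in cls._ATEN_OP_TABLE:
            # args/kwargs stay packed, mirroring the new torchao convention
            return cls._ATEN_OP_TABLE[func](func, types, args, kwargs)
        raise NotImplementedError(f"MyToyTensor: no handler for {func}")


@MyToyTensor.implements([aten.detach.default])
def _(func, types, args, kwargs):
    # args[0] is the MyToyTensor instance; unpack explicitly inside the handler
    return MyToyTensor(args[0].inner.detach())


x = MyToyTensor(torch.randn(2, 3))
y = x.detach()
print(type(y).__name__, y.inner.shape)  # MyToyTensor torch.Size([2, 3])
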
8 changes: 4 additions & 4 deletions torchao/prototype/low_bit_optim/subclass_4bit.py
@@ -89,7 +89,7 @@ def __repr__(self):


@OptimState4bit.implements(aten.copy_.default)
def _(func, types, *args, **kwargs):
def _(func, types, args, kwargs):
dst = args[0]
src = args[1]

@@ -116,14 +116,14 @@ def _(func, types, *args, **kwargs):


@OptimState4bit.implements(aten.lerp.Scalar)
def _(func, types, *args, **kwargs):
def _(func, types, args, kwargs):
args = [x.dequantize() if isinstance(x, OptimState4bit) else x for x in args]
return func(*args, **kwargs)


# this is needed for DTensor.from_local() and for flattening tensor
@OptimState4bit.implements(aten.view.default)
def _(func, types, *args, **kwargs):
def _(func, types, args, kwargs):
x, shape = args

if tuple(x.shape) == tuple(shape):
@@ -142,7 +142,7 @@ def _(func, types, *args, **kwargs):
c10d_functional.wait_tensor.default,
_c10d_functional.wait_tensor.default,
])
def _(func, types, *args, **kwargs):
def _(func, types, args, kwargs):
x = args[0]
if not isinstance(x, OptimState4bit):
raise ValueError(f"expecting a OptimState4bit but found {type(x)}")
8 changes: 4 additions & 4 deletions torchao/prototype/low_bit_optim/subclass_8bit.py
@@ -75,7 +75,7 @@ def __repr__(self):


@OptimState8bit.implements(aten.copy_.default)
def _(func, types, *args, **kwargs):
def _(func, types, args, kwargs):
dst = args[0]
src = args[1]

@@ -98,14 +98,14 @@ def _(func, types, *args, **kwargs):


@OptimState8bit.implements(aten.lerp.Scalar)
def _(func, types, *args, **kwargs):
def _(func, types, args, kwargs):
args = [x.dequantize() if isinstance(x, OptimState8bit) else x for x in args]
return func(*args, **kwargs)


# this is needed for DTensor.from_local()
@OptimState8bit.implements(aten.view.default)
def _(func, types, *args, **kwargs):
def _(func, types, args, kwargs):
x, shape = args
return OptimState8bit(x.codes.view(shape), x.scale, x.qmap, x.signed)

@@ -117,7 +117,7 @@ def _(func, types, *args, **kwargs):
c10d_functional.wait_tensor.default,
_c10d_functional.wait_tensor.default,
])
def _(func, types, *args, **kwargs):
def _(func, types, args, kwargs):
x = args[0]
if not isinstance(x, OptimState8bit):
raise ValueError(f"expecting a OptimState8bit but found {type(x)}")
8 changes: 4 additions & 4 deletions torchao/prototype/low_bit_optim/subclass_fp8.py
@@ -81,7 +81,7 @@ def __repr__(self):


@OptimStateFp8.implements(aten.copy_.default)
def _(func, types, *args, **kwargs):
def _(func, types, args, kwargs):
dst = args[0]
src = args[1]

@@ -102,14 +102,14 @@ def _(func, types, *args, **kwargs):


@OptimStateFp8.implements(aten.lerp.Scalar)
def _(func, types, *args, **kwargs):
def _(func, types, args, kwargs):
args = [x.dequantize() if isinstance(x, OptimStateFp8) else x for x in args]
return func(*args, **kwargs)


# this is needed for DTensor.from_local()
@OptimStateFp8.implements(aten.view.default)
def _(func, types, *args, **kwargs):
def _(func, types, args, kwargs):
x, shape = args
return OptimStateFp8(x.codes.view(shape), x.scale)

@@ -121,7 +121,7 @@ def _(func, types, *args, **kwargs):
c10d_functional.wait_tensor.default,
_c10d_functional.wait_tensor.default,
])
def _(func, types, *args, **kwargs):
def _(func, types, args, kwargs):
x = args[0]
if not isinstance(x, OptimStateFp8):
raise ValueError(f"expecting a OptimStateFp8 but found {type(x)}")