Commit d5aed60

Move more utils to TorchAOBaseTensor (pytorch#784)

Summary: Moves _implements, _dispatch__torch_dispatch__, _dispatch__torch_function__, _register_layout_cls, and _get_layout_tensor_constructor over to `TorchAOBaseTensor`, so that tensor subclasses which inherit from it get these utilities directly.

Test Plan:
python test/quantization/test_quant_api.py
python test/integration/test_integration.py
Rely on CI for the other tests.

1 parent 4d88ec3 commit d5aed60
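
To make the change concrete, here is a minimal sketch (my illustration, not code from this commit) of the pattern it enables: a tensor subclass inherits from TorchAOBaseTensor and picks up `implements`, `__torch_function__`, and `__torch_dispatch__` directly, instead of wiring them up with `classmethod(...)` as the files below used to. `MyTensor`, its `inner` attribute, and the registered handler are hypothetical.

import torch
from torch.utils._python_dispatch import return_and_correct_aliasing
from torchao.utils import TorchAOBaseTensor

aten = torch.ops.aten

class MyTensor(TorchAOBaseTensor):
    # Dispatch plumbing (implements / __torch_function__ / __torch_dispatch__)
    # is inherited from TorchAOBaseTensor; no per-class classmethod boilerplate.
    @staticmethod
    def __new__(cls, data: torch.Tensor):
        return torch.Tensor._make_wrapper_subclass(cls, data.shape, dtype=data.dtype, device=data.device)

    def __init__(self, data: torch.Tensor):
        self.inner = data

@MyTensor.implements(aten.detach.default)
def _(func, types, args, kwargs):
    # Hypothetical handler: detach the inner tensor and rewrap,
    # preserving aliasing metadata for the dispatcher.
    return return_and_correct_aliasing(func, args, kwargs, MyTensor(args[0].inner.detach()))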

14 files changed: +205 additions, −233 deletions

test/prototype/test_low_bit_optim.py

Lines changed: 1 addition & 1 deletion

@@ -75,7 +75,7 @@ def test_quantize_4bit_with_qmap_compile(self, device):
 
 
 class TestOptim(TestCase):
-    @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3")
+    @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_4, reason="requires PyTorch >= 2.3")
     @parametrize("optim_name", ["Adam8bit", "AdamW8bit", "Adam4bit", "AdamW4bit", "AdamFp8", "AdamWFp8"])
     @parametrize("dtype", [torch.float32, torch.bfloat16])
     @parametrize("device", _DEVICES)

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 4 additions & 14 deletions

@@ -21,11 +21,6 @@
 )
 from torch.utils._python_dispatch import return_and_correct_aliasing
 from torchao.dtypes.utils import (
-    _implements,
-    _dispatch__torch_function__,
-    _dispatch__torch_dispatch__,
-    _register_layout_cls,
-    _get_layout_tensor_constructor,
     LayoutType,
     PlainLayoutType,
     is_device,
@@ -405,7 +400,8 @@ def _apply_fn_to_data(self, fn):
            strides=self.stride(),
        )
 
-    implements = classmethod(_implements)
+    # following are the comments for __torch_function__/__torch_dispatch__, we can clean this up
+    # a bit later
    # Note: we only added cpu path here for 8da4w, this is for executorch, in the future
    # 1. we'll add cpu/cuda version (int4mm etc.)
    # 2. we'll need to hide the 8da4w executorch version under things like layouts (we also have multiple impl for cpu kernel as Michael mentioned), so it will be something like
@@ -417,19 +413,13 @@ def _apply_fn_to_data(self, fn):
    # 1 - when tensor is on CUDA: we'll add this later, we'll also enable dispatching to optimized
    # kernels in CPU as well, see the note above
    # 2 - we're given non-floats - quantizing long to int8 is crazy
-    __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)
-    __torch_function__ = classmethod(_dispatch__torch_function__)
 
 
 ######################################################
 # LayoutType and Layout Tensor Subclass Registration #
 ######################################################
-
-def register_layout_cls(layout_type_class: type(LayoutType)):
-    return _register_layout_cls(AffineQuantizedTensor, layout_type_class)
-
-def get_layout_tensor_constructor(layout_type_class: type(LayoutType)):
-    return _get_layout_tensor_constructor(AffineQuantizedTensor, layout_type_class)
+register_layout_cls = AffineQuantizedTensor.register_layout_cls
+get_layout_tensor_constructor = AffineQuantizedTensor.get_layout_tensor_constructor
 
 @dataclass(frozen=True)
 class SemiSparseLayoutType(LayoutType):

torchao/dtypes/fpx/fpx.py

Lines changed: 0 additions & 4 deletions

@@ -8,14 +8,10 @@
 from torchao.ops import quant_llm_linear
 from torchao.dtypes.utils import (
     LayoutType,
-    _implements,
-    _dispatch__torch_function__,
-    _dispatch__torch_dispatch__,
 )
 from torchao.quantization.quant_api import _get_linear_subclass_inserter
 from dataclasses import dataclass
 from torchao.dtypes.affine_quantized_tensor import AQTLayout, register_layout_cls
-from torchao.utils import TorchAOBaseTensor
 
 
 aten = torch.ops.aten

torchao/dtypes/uintx/Uintx.py

Lines changed: 2 additions & 8 deletions

@@ -6,10 +6,8 @@
 from .bitpacking import pack, unpack
 from torchao.dtypes.utils import (
     LayoutType,
-    _implements,
-    _dispatch__torch_function__,
-    _dispatch__torch_dispatch__,
 )
+from torchao.utils import TorchAOBaseTensor
 from torchao.dtypes.affine_quantized_tensor import PlainAQTLayout, register_layout_cls
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_3
 
@@ -35,7 +33,7 @@
     print("uintx feature need torch 2.3+, please upgrade pytorch")
 
 
-class UintxTensor(torch.Tensor):
+class UintxTensor(TorchAOBaseTensor):
     """
     Splits int data into packed shards based on bit size
     fields:
@@ -99,10 +97,6 @@ def __tensor_unflatten__(
         packed_shape, bit_width, pack_dim = tensor_attributes
         return cls(shards, packed_shape, bit_width, pack_dim)
 
-    implements = classmethod(_implements)
-    __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)
-    __torch_function__ = classmethod(_dispatch__torch_function__)
-
     def get_plain(self):
         return unpack(self.get_shards(), self.bit_width, dim = self.pack_dim)
torchao/dtypes/utils.py

Lines changed: 1 addition & 116 deletions

@@ -1,75 +1,6 @@
 import torch
-from typing import Dict, Callable, Union, Tuple, Optional
-from collections import defaultdict
-import functools
+from typing import Union, Tuple
 from dataclasses import dataclass
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
-
-"""
-Helper function for implementing aten op or torch function dispatch
-and dispatching to these implementations.
-"""
-def _implements(cls, aten_ops_or_torch_fns):
-    """Use this decorator to implement a function for an aten ops in __torch_dispatch__
-    (if user passed in a list of ops)
-    or torch function in __torch_function__ (if user passed in a single object)
-
-    class MyTensor(torch.Tensor):
-        ...
-        implements = classmethod(_implements)
-
-    implements = MyTensor.implements
-
-    @implements(torch.nn.functional.linear):
-    def _(func, types, args, kwargs):
-        ...
-
-    """
-    if not hasattr(cls, "_ATEN_OP_OR_TORCH_FN_TABLE"):
-        cls._ATEN_OP_OR_TORCH_FN_TABLE = {}
-
-    if not isinstance(aten_ops_or_torch_fns, (list, tuple)):
-        aten_ops_or_torch_fns = [aten_ops_or_torch_fns]
-    def decorator(func):
-        for op in aten_ops_or_torch_fns:
-            @functools.wraps(op)
-            def wrapper(f, types, args, kwargs):
-                return func(f, types, args, kwargs)
-
-            cls._ATEN_OP_OR_TORCH_FN_TABLE[op] = wrapper
-        return func
-    return decorator
-
-def _dispatch__torch_function__(cls, func, types, args=(), kwargs=None):
-    """Use this util function for a common `__torch_function__` implementation
-    that dispatches to ops/functions registered with `_implements`
-
-    class MyTensor(torch.Tensor):
-        ...
-        __torch_function__ = classmethod(_dispatch__torch_function__)
-    """
-    kwargs = {} if kwargs is None else kwargs
-    if hasattr(cls, "_ATEN_OP_OR_TORCH_FN_TABLE") and \
-        func in cls._ATEN_OP_OR_TORCH_FN_TABLE:
-        return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, args, kwargs)
-
-    with torch._C.DisableTorchFunctionSubclass():
-        return func(*args, **kwargs)
-
-def _dispatch__torch_dispatch__(cls, func, types, args, kwargs):
-    """Use this util function for a common `__torch_dispatch__` implementation
-    that dispatches to ops/functions registered with `_implements`
-
-    class MyTensor(torch.Tensor):
-        ...
-        __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)
-    """
-    if hasattr(cls, "_ATEN_OP_OR_TORCH_FN_TABLE") and \
-        func in cls._ATEN_OP_OR_TORCH_FN_TABLE:
-        return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, args, kwargs)
-
-    raise NotImplementedError(f"{cls.__name__} dispatch: attempting to run unimplemented operator/function: {func}")
-
 
 """
 Base class for different LayoutType, should not be instantiated directly
@@ -101,52 +32,6 @@ def extra_repr(self) -> str:
 class PlainLayoutType(LayoutType):
     pass
 
-"""
-layout tensor constructor registration for different tensor subclassesa
-
-first key is a tensor subclass type like AffineQuantizedTensor
-second key is an extended layout string, like tensor_core_tiled
-value is a constructor for the LayoutTensor class, e.g. TensorCoreTiledAQTLayout.from_plain
-"""
-_LAYOUT_CONSTRUCTOR_TABLE: Dict[Callable, Dict[type(LayoutType), Callable]] = defaultdict(dict)
-
-def _register_layout_cls(cls: Callable, layout_type_class: type(LayoutType)):
-    """Helper function for layout registrations, this is used to implement
-    register_layout_cls decorator for each tensor subclass, see aqt.py for example usage
-
-    Args:
-        cls: Tensor subclass type
-        layout_type_class: the class type of subclass of `LayoutType`, e.g. `PlainLayoutType`
-
-    Returns:
-        a decorator that registers the layout tensor constructor in the table
-    """
-    def decorator(layout_cls):
-        _LAYOUT_CONSTRUCTOR_TABLE[cls][layout_type_class] = layout_cls.from_plain
-        if TORCH_VERSION_AT_LEAST_2_5:
-            # Allow serialization to work for models uses this layout tensor subclass
-            torch.serialization.add_safe_globals([layout_type_class, layout_cls])
-        return layout_cls
-    return decorator
-
-def _get_layout_tensor_constructor(cls: Callable, layout_type_class: type(LayoutType)) -> Callable:
-    """Get Layout class constructor (LayoutClass.from_plain) for `cls` based on `layout_type_class`
-    `layout_type_class` means the class type of subclass of `LayoutType`, e.g. `PlainLayoutType`
-
-    Args:
-        cls: Tensor subclass type
-        layout_type_class: the class type of subclass of `LayoutType`, e.g. `PlainLayoutType`
-
-    Returns:
-        layout tensor subclass constructor for the layout_type_class
-    """
-    if cls not in _LAYOUT_CONSTRUCTOR_TABLE:
-        raise ValueError(f"no registered layout class constructor for: {cls}")
-    if layout_type_class not in _LAYOUT_CONSTRUCTOR_TABLE[cls]:
-        raise ValueError(f"layout_name: {layout_type_class} is not supported yet for {cls}")
-
-    return _LAYOUT_CONSTRUCTOR_TABLE[cls][layout_type_class]
-
 def is_device(target_device_str: str, device: Union[str, torch.device]):
     return torch.device(device).type == target_device_str
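
For context on how the relocated layout registration is meant to be used after this change, here is a small sketch (my own, not part of the diff), assuming `register_layout_cls` and `get_layout_tensor_constructor` behave like the removed helpers above, i.e. the decorator records the decorated layout tensor class's `from_plain` for a given `LayoutType` subclass. `MyLayoutType`, `MyAQTLayout`, and the exact `from_plain` signature are hypothetical.

from dataclasses import dataclass

from torchao.dtypes.utils import LayoutType
from torchao.dtypes.affine_quantized_tensor import (
    AQTLayout,
    register_layout_cls,             # alias for AffineQuantizedTensor.register_layout_cls
    get_layout_tensor_constructor,   # alias for AffineQuantizedTensor.get_layout_tensor_constructor
)

@dataclass(frozen=True)
class MyLayoutType(LayoutType):
    pass

@register_layout_cls(MyLayoutType)
class MyAQTLayout(AQTLayout):
    @classmethod
    def from_plain(cls, int_data, scale, zero_point, layout_type):
        # Hypothetical: pack the plain quantized data into this layout here.
        ...

# The registered constructor can then be looked up by layout type;
# this should resolve to MyAQTLayout.from_plain.
ctor = get_layout_tensor_constructor(MyLayoutType)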

torchao/prototype/low_bit_optim/subclass_4bit.py

Lines changed: 2 additions & 5 deletions

@@ -2,7 +2,7 @@
 
 import torch
 from torch import Tensor
-from torchao.dtypes.utils import _implements, _dispatch__torch_dispatch__
+from torchao.utils import TorchAOBaseTensor
 
 from .quant_utils import create_dynamic_map, scale_tensor, quantize_4bit_with_qmap, dequant_with_qmap
 
@@ -18,8 +18,7 @@
 QMAP_UNSIGNED = torch.linspace(0, 1, 17)[1:].tolist() # no zero
 
 
-class OptimState4bit(Tensor):
-    implements = classmethod(_implements)
+class OptimState4bit(TorchAOBaseTensor):
     tensor_attrs = ["codes", "scale", "qmap"]
 
     @staticmethod
@@ -80,8 +79,6 @@ def __repr__(self):
            f"shape={tuple(self.shape)}, device={self.device}, requires_grad={self.requires_grad})"
        )
 
-    __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)
-
 
 @OptimState4bit.implements(aten.copy_.default)
 def _(func, types, args, kwargs):

torchao/prototype/low_bit_optim/subclass_8bit.py

Lines changed: 2 additions & 5 deletions

@@ -1,6 +1,6 @@
 import torch
 from torch import Tensor
-from torchao.dtypes.utils import _implements, _dispatch__torch_dispatch__
+from torchao.utils import TorchAOBaseTensor
 
 from .quant_utils import create_dynamic_map, scale_tensor, quantize_8bit_with_qmap, dequant_with_qmap
 
@@ -13,8 +13,7 @@
 QMAP_UNSIGNED = create_dynamic_map(signed=False)
 
 
-class OptimState8bit(Tensor):
-    implements = classmethod(_implements)
+class OptimState8bit(TorchAOBaseTensor):
     tensor_attrs = ["codes", "scale", "qmap"]
 
     @staticmethod
@@ -66,8 +65,6 @@ def __repr__(self):
            f"shape={tuple(self.shape)}, device={self.device}, requires_grad={self.requires_grad})"
        )
 
-    __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)
-
 
 @OptimState8bit.implements(aten.copy_.default)
 def _(func, types, args, kwargs):

torchao/prototype/low_bit_optim/subclass_fp8.py

Lines changed: 2 additions & 5 deletions

@@ -1,6 +1,6 @@
 import torch
 from torch import Tensor
-from torchao.dtypes.utils import _implements, _dispatch__torch_dispatch__
+from torchao.utils import TorchAOBaseTensor
 
 
 aten = torch.ops.aten
@@ -21,8 +21,7 @@ def quantize_fp8(input: Tensor, block_size: int):
 
 # NOTE: FP8 sign bit is redundant for unsigned optim state.
 # we may investigate how to use it to increase range/precision for unsigned optim state.
-class OptimStateFp8(Tensor):
-    implements = classmethod(_implements)
+class OptimStateFp8(TorchAOBaseTensor):
     tensor_attrs = ["codes", "scale"]
 
     @staticmethod
@@ -72,8 +71,6 @@ def __repr__(self):
            f"shape={tuple(self.shape)}, device={self.device}, requires_grad={self.requires_grad})"
        )
 
-    __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)
-
 
 @OptimStateFp8.implements(aten.copy_.default)
 def _(func, types, args, kwargs):

torchao/prototype/quantized_training/int8.py

Lines changed: 2 additions & 6 deletions

@@ -4,15 +4,15 @@
 from torch import Tensor, nn
 from torch.utils._python_dispatch import return_and_correct_aliasing
 
-from torchao.dtypes.utils import _dispatch__torch_dispatch__, _dispatch__torch_function__, _implements
+from torchao.utils import TorchAOBaseTensor
 
 
 aten = torch.ops.aten
 c10d_functional = torch.ops.c10d_functional
 _c10d_functional = torch.ops._c10d_functional
 
 
-class Int8QTLinearWeight(Tensor):
+class Int8QTLinearWeight(TorchAOBaseTensor):
     """INT8 symmetric quantization weight, with absmax scaling [-127, 127]. The main difference
     of this tensor subclass from AffineQuantizedTensor:
     1. `F.linear` is differentiable i.e. backward is defined.
@@ -22,10 +22,6 @@ class Int8QTLinearWeight(Tensor):
     for more details.
     """
 
-    implements = classmethod(_implements)
-    __torch_function__ = classmethod(_dispatch__torch_function__)
-    __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)
-
     @staticmethod
     @torch._dynamo.disable
     def __new__(cls, int_data: Tensor, scale: Tensor):

torchao/quantization/linear_activation_quantized_tensor.py

Lines changed: 0 additions & 9 deletions

@@ -1,9 +1,4 @@
 import torch
-from torchao.dtypes.utils import (
-    _implements,
-    _dispatch__torch_function__,
-    _dispatch__torch_dispatch__,
-)
 from typing import Callable
 from torch.utils._python_dispatch import return_and_correct_aliasing
 from torchao.utils import (
@@ -94,10 +89,6 @@ def to(self, *args, **kwargs):
            self.input_quant_func,
        )
 
-    implements = classmethod(_implements)
-    __torch_function__ = classmethod(_dispatch__torch_function__)
-    __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)
-
 implements = LinearActivationQuantizedTensor.implements
 
 @implements(torch.nn.functional.linear)
