@@ -9,11 +9,7 @@
     MappingType,
 )
 from torch.utils._python_dispatch import return_and_correct_aliasing
-from torchao.dtypes.utils import (
-    _implements,
-    _dispatch__torch_function__,
-    _dispatch__torch_dispatch__,
-)
+from torchao.utils import TorchAOBaseTensor
 from .utils import (
     _GenericFakeQuantize,
     _UnwrapAffineFakeQuantizedTensor,
@@ -80,7 +76,7 @@ def backward(ctx, gy):
         return gy, None, None, None, None, None, None, None, None, None, None
 
 
-class AffineFakeQuantizedTensor(torch.Tensor):
+class AffineFakeQuantizedTensor(TorchAOBaseTensor):
     """
     Affine fake quantized tensor subclass. Affine quantization means we quantize the floating point tensor
     with an affine transformation:
@@ -179,15 +175,15 @@ def _get_to_kwargs(self, *args, **kwargs):
         device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
         device = self.device if device is None else device
         dtype = self.dtype if dtype is None else dtype
-        memory_format = (
+        memory_format = (
             memory_format if memory_format is not None else torch.preserve_format
-        )
-        kwargs = {
+        )
+        kwargs = {
             "device": device,
             "dtype": dtype,
             "memory_format": memory_format,
             "requires_grad": self.requires_grad,
-        }
+        }
         return kwargs
 
     def to(self, *args, **kwargs):
@@ -227,10 +223,6 @@ def _create_new(self, new_value: torch.Tensor):
             requires_grad=False,
         )
 
-    implements = classmethod(_implements)
-    __torch_function__ = classmethod(_dispatch__torch_function__)
-    __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)
-
 implements = AffineFakeQuantizedTensor.implements
 
 
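For context on the mechanics: the three classmethod assignments deleted in the last hunk are exactly the hooks that torchao.utils.TorchAOBaseTensor bundles, which is why the module-level `implements = AffineFakeQuantizedTensor.implements` alias keeps working unchanged after the base-class swap. Below is a minimal sketch of the pattern. It assumes only what the diff itself shows (TorchAOBaseTensor supplies implements, __torch_function__, and __torch_dispatch__) plus torchao's usual handler convention of `(func, types, args, kwargs)`; MyFakeQuantizedTensor and its detach handler are hypothetical illustrations, not code from this commit.

import torch
from torchao.utils import TorchAOBaseTensor

aten = torch.ops.aten

# Hypothetical subclass: inheriting from TorchAOBaseTensor pulls in the
# implements / __torch_function__ / __torch_dispatch__ boilerplate that
# each tensor subclass previously had to assign by hand.
class MyFakeQuantizedTensor(TorchAOBaseTensor):
    pass

# Same module-level alias style this file uses for AffineFakeQuantizedTensor.
implements = MyFakeQuantizedTensor.implements

# Register an op override through the inherited classmethod; handlers
# receive (func, types, args, kwargs) and are looked up by the base
# class's __torch_dispatch__.
@implements([aten.detach.default])
def _(func, types, args, kwargs):
    # Illustrative handler only: return the input unchanged. A real handler
    # would detach the wrapped data and re-wrap it in the subclass.
    return args[0]

Because the registration table is keyed per subclass inside the base class, deleting the three assignments does not disturb any of the existing `implements(...)` registrations later in the file.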