Skip to content

Commit d222713

Browse files
committed
refactor affine fake quantized tensor
1 parent f3d86c6 commit d222713

File tree

1 file changed

+6
-14
lines changed

1 file changed

+6
-14
lines changed

torchao/quantization/prototype/qat/affine_fake_quantized_tensor.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,7 @@
99
MappingType,
1010
)
1111
from torch.utils._python_dispatch import return_and_correct_aliasing
12-
from torchao.dtypes.utils import (
13-
_implements,
14-
_dispatch__torch_function__,
15-
_dispatch__torch_dispatch__,
16-
)
12+
from torchao.utils import TorchAOBaseTensor
1713
from .utils import (
1814
_GenericFakeQuantize,
1915
_UnwrapAffineFakeQuantizedTensor,
@@ -80,7 +76,7 @@ def backward(ctx, gy):
8076
return gy, None, None, None, None, None, None, None, None, None, None
8177

8278

83-
class AffineFakeQuantizedTensor(torch.Tensor):
79+
class AffineFakeQuantizedTensor(TorchAOBaseTensor):
8480
"""
8581
Affine fake quantized tensor subclass. Affine quantization means we quantize the floating point tensor
8682
with an affine transformation:
@@ -179,15 +175,15 @@ def _get_to_kwargs(self, *args, **kwargs):
179175
device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
180176
device = self.device if device is None else device
181177
dtype = self.dtype if dtype is None else dtype
182-
memory_format = (
178+
memory_format = (
183179
memory_format if memory_format is not None else torch.preserve_format
184-
)
185-
kwargs = {
180+
)
181+
kwargs = {
186182
"device": device,
187183
"dtype": dtype,
188184
"memory_format": memory_format,
189185
"requires_grad": self.requires_grad,
190-
}
186+
}
191187
return kwargs
192188

193189
def to(self, *args, **kwargs):
@@ -227,10 +223,6 @@ def _create_new(self, new_value: torch.Tensor):
227223
requires_grad=False,
228224
)
229225

230-
implements = classmethod(_implements)
231-
__torch_function__ = classmethod(_dispatch__torch_function__)
232-
__torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)
233-
234226
implements = AffineFakeQuantizedTensor.implements
235227

236228

0 commit comments

Comments (0)