From 1264f7a0ce297c9e33a1b5865ea84b1fde4f0856 Mon Sep 17 00:00:00 2001 From: Tugsbayasgalan Manlaibaatar Date: Tue, 18 Jun 2024 09:42:07 -0700 Subject: [PATCH 1/3] Add repr method on tensor subclass [ghstack-poisoned] --- torchao/dtypes/aqt.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/torchao/dtypes/aqt.py b/torchao/dtypes/aqt.py index 83c7d22fb4..4103024037 100644 --- a/torchao/dtypes/aqt.py +++ b/torchao/dtypes/aqt.py @@ -480,6 +480,9 @@ def _apply_fn_to_data(self, fn): self.scale_and_zero = fn(self.scale_and_zero) return self + def __repr__(self): + return f"TensorCoreTiledAQTLayout(packed_weight={self.packed_weight}, scale_and_zero={self.scale_and_zero})" + @classmethod def __torch_dispatch__(cls, func, types, args, kwargs): kwargs = {} if kwargs is None else kwargs @@ -495,6 +498,8 @@ def __torch_dispatch__(cls, func, types, args, kwargs): """ args[0].transposed = not args[0].transposed return return_and_correct_aliasing(func, args, kwargs, args[0]) + + breakpoint() raise NotImplementedError( f"TensorCoreTiledAQTLayout dispatch: attempting to run {func}, this is not supported" From 5875f9ab2dca1421b2d0ac12abf64a161cfa192c Mon Sep 17 00:00:00 2001 From: Tugsbayasgalan Manlaibaatar Date: Tue, 18 Jun 2024 09:42:47 -0700 Subject: [PATCH 2/3] Update on "Add repr method on tensor subclass" [ghstack-poisoned] --- torchao/dtypes/aqt.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/torchao/dtypes/aqt.py b/torchao/dtypes/aqt.py index 4103024037..d35b186743 100644 --- a/torchao/dtypes/aqt.py +++ b/torchao/dtypes/aqt.py @@ -498,8 +498,6 @@ def __torch_dispatch__(cls, func, types, args, kwargs): """ args[0].transposed = not args[0].transposed return return_and_correct_aliasing(func, args, kwargs, args[0]) - - breakpoint() raise NotImplementedError( f"TensorCoreTiledAQTLayout dispatch: attempting to run {func}, this is not supported" From c22adf222840dff6e4f067d69e9dc279f3923e7a Mon Sep 17 00:00:00 2001 From: Tugsbayasgalan Manlaibaatar Date: Tue, 18 Jun 2024 10:10:38 -0700 Subject: [PATCH 3/3] Update on "Add repr method on tensor subclass" Without this, it will go through default tensor __repr__ method which has some aten ops to print the tensor. This results in triggerring subclass's torch_dispatch. [ghstack-poisoned] --- torchao/dtypes/aqt.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchao/dtypes/aqt.py b/torchao/dtypes/aqt.py index d35b186743..666f7c1947 100644 --- a/torchao/dtypes/aqt.py +++ b/torchao/dtypes/aqt.py @@ -481,7 +481,8 @@ def _apply_fn_to_data(self, fn): return self def __repr__(self): - return f"TensorCoreTiledAQTLayout(packed_weight={self.packed_weight}, scale_and_zero={self.scale_and_zero})" + int_data, scale, zero_point = self.get_plain() + return f"TensorCoreTiledAQTLayout(int_data={int_data}, scale={scale}, zero_point={zero_point})" @classmethod def __torch_dispatch__(cls, func, types, args, kwargs):