diff --git a/torchao/dtypes/aqt.py b/torchao/dtypes/aqt.py index 83c7d22fb4..666f7c1947 100644 --- a/torchao/dtypes/aqt.py +++ b/torchao/dtypes/aqt.py @@ -480,6 +480,10 @@ def _apply_fn_to_data(self, fn): self.scale_and_zero = fn(self.scale_and_zero) return self + def __repr__(self): + int_data, scale, zero_point = self.get_plain() + return f"TensorCoreTiledAQTLayout(int_data={int_data}, scale={scale}, zero_point={zero_point})" + @classmethod def __torch_dispatch__(cls, func, types, args, kwargs): kwargs = {} if kwargs is None else kwargs