diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py
index 83c7d22fb4..666f7c1947 100644
--- a/torchao/dtypes/affine_quantized_tensor.py
+++ b/torchao/dtypes/affine_quantized_tensor.py
@@ -480,6 +480,10 @@ def _apply_fn_to_data(self, fn):
         self.scale_and_zero = fn(self.scale_and_zero)
         return self
 
+    def __repr__(self):
+        int_data, scale, zero_point = self.get_plain()
+        return f"TensorCoreTiledAQTLayout(int_data={int_data}, scale={scale}, zero_point={zero_point})"
+
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs):
         kwargs = {} if kwargs is None else kwargs
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 33821f1d82..31ab71f385 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -310,13 +310,6 @@ def filter_fn(module: nn.Module, fqn: str) -> bool:
     """
     if set_inductor_config:
         torchao.quantization.utils.recommended_inductor_config_setter()
-    if isinstance(apply_tensor_subclass, str):
-        if apply_tensor_subclass not in _APPLY_TS_TABLE:
-            raise ValueError(f"{apply_tensor_subclass} not supported: {_APPLY_TS_TABLE.keys()}")
-        apply_tensor_subclass = _APPLY_TS_TABLE[apply_tensor_subclass]
-
-    assert not isinstance(apply_tensor_subclass, str)
-
     _replace_with_custom_fn_if_matches_filter(
         model,
         _get_linear_subclass_inserter(apply_tensor_subclass),
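
For context on the second hunk: with the string-alias lookup removed, apply_tensor_subclass must be passed as a callable rather than a key into _APPLY_TS_TABLE. Below is a minimal sketch of the resulting call pattern. The names quantize_ and int4_weight_only are taken from torchao's public quantization API and do not appear in this diff, so treat the exact import path and spelling as assumptions.

    # Sketch of the call pattern after this change (entry-point names assumed, not shown in the diff).
    import torch
    from torchao.quantization import quantize_, int4_weight_only

    # int4 weight-only quantization targets bfloat16 linears on CUDA.
    model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16).cuda()

    # Previously a string alias was looked up in _APPLY_TS_TABLE; now the caller
    # constructs and passes the callable directly.
    quantize_(model, int4_weight_only())

    # The TensorCoreTiledAQTLayout.__repr__ added in the first hunk means the packed
    # layout tensor now prints its unpacked int_data, scale and zero_point via
    # get_plain(); the exact nesting in the printout depends on the outer
    # AffineQuantizedTensor repr.
    print(model[0].weight)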