Skip to content

Commit 50db29f

Browse files
committed
update
1 parent 4e4d11c commit 50db29f

File tree

1 file changed

+27
-12
lines changed

1 file changed

+27
-12
lines changed

torchao/quantization/quant_api.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,6 @@
9898
to_weight_tensor_with_linear_activation_quantization_metadata,
9999
)
100100
from torchao.utils import (
101-
TorchAOBaseTensor,
102101
_ConfigDeprecationWrapper,
103102
is_MI300,
104103
is_sm_at_least_89,
@@ -416,16 +415,17 @@ def _embedding_extra_repr(self):
416415
return f"num_embeddings={self.weight.shape[0]}, embedding_dim={self.weight.shape[1]}, weight={_quantization_type(self.weight)}"
417416

418417

419-
def _module_extra_repr(self, original_extra_repr):
420-
module_torchao_param_extra_repr = [
421-
f"{name}={_quantization_type(getattr(self, name))}"
422-
for name, param in self.named_parameters()
423-
if isinstance(param, TorchAOBaseTensor)
424-
]
418+
def _module_extra_repr(self, original_extra_repr, parameter_name):
419+
module_torchao_extra_repr = []
420+
425421
original_extra_repr_str = original_extra_repr()
426422
if len(original_extra_repr_str) > 0:
427-
module_torchao_param_extra_repr.insert(0, original_extra_repr_str)
428-
return ", ".join(module_torchao_param_extra_repr)
423+
module_torchao_extra_repr.append(original_extra_repr_str)
424+
425+
module_torchao_extra_repr.append(
426+
f"{parameter_name}={_quantization_type(getattr(self, parameter_name))}"
427+
)
428+
return ", ".join(module_torchao_extra_repr)
429429

430430

431431
def _get_linear_subclass_inserter(
@@ -1396,7 +1396,12 @@ def _int8_weight_only_transform(
13961396
torch.nn.Parameter(quantized_tensor, requires_grad=False),
13971397
)
13981398
module.extra_repr = types.MethodType(
1399-
partial(_module_extra_repr, original_extra_repr=module.extra_repr), module
1399+
partial(
1400+
_module_extra_repr,
1401+
original_extra_repr=module.extra_repr,
1402+
parameter_name=parameter_name,
1403+
),
1404+
module,
14001405
)
14011406
return module
14021407

@@ -1692,7 +1697,12 @@ def _float8_weight_only_transform(
16921697
torch.nn.Parameter(quantized_tensor, requires_grad=False),
16931698
)
16941699
module.extra_repr = types.MethodType(
1695-
partial(_module_extra_repr, original_extra_repr=module.extra_repr), module
1700+
partial(
1701+
_module_extra_repr,
1702+
original_extra_repr=module.extra_repr,
1703+
parameter_name=parameter_name,
1704+
),
1705+
module,
16961706
)
16971707
return module
16981708

@@ -1941,7 +1951,12 @@ def _float8_dynamic_activation_float8_weight_transform(
19411951
torch.nn.Parameter(quantized_tensor, requires_grad=False),
19421952
)
19431953
module.extra_repr = types.MethodType(
1944-
partial(_module_extra_repr, original_extra_repr=module.extra_repr), module
1954+
partial(
1955+
_module_extra_repr,
1956+
original_extra_repr=module.extra_repr,
1957+
parameter_name=parameter_name,
1958+
),
1959+
module,
19451960
)
19461961
return module
19471962

0 commit comments

Comments
 (0)