|
98 | 98 | to_weight_tensor_with_linear_activation_quantization_metadata, |
99 | 99 | ) |
100 | 100 | from torchao.utils import ( |
101 | | - TorchAOBaseTensor, |
102 | 101 | _ConfigDeprecationWrapper, |
103 | 102 | is_MI300, |
104 | 103 | is_sm_at_least_89, |
@@ -416,16 +415,17 @@ def _embedding_extra_repr(self): |
416 | 415 | return f"num_embeddings={self.weight.shape[0]}, embedding_dim={self.weight.shape[1]}, weight={_quantization_type(self.weight)}" |
417 | 416 |
|
418 | 417 |
|
419 | | -def _module_extra_repr(self, original_extra_repr): |
420 | | - module_torchao_param_extra_repr = [ |
421 | | - f"{name}={_quantization_type(getattr(self, name))}" |
422 | | - for name, param in self.named_parameters() |
423 | | - if isinstance(param, TorchAOBaseTensor) |
424 | | - ] |
| 418 | +def _module_extra_repr(self, original_extra_repr, parameter_name): |
| 419 | + module_torchao_extra_repr = [] |
| 420 | + |
425 | 421 | original_extra_repr_str = original_extra_repr() |
426 | 422 | if len(original_extra_repr_str) > 0: |
427 | | - module_torchao_param_extra_repr.insert(0, original_extra_repr_str) |
428 | | - return ", ".join(module_torchao_param_extra_repr) |
| 423 | + module_torchao_extra_repr.append(original_extra_repr_str) |
| 424 | + |
| 425 | + module_torchao_extra_repr.append( |
| 426 | + f"{parameter_name}={_quantization_type(getattr(self, parameter_name))}" |
| 427 | + ) |
| 428 | + return ", ".join(module_torchao_extra_repr) |
429 | 429 |
|
430 | 430 |
|
431 | 431 | def _get_linear_subclass_inserter( |
@@ -1396,7 +1396,12 @@ def _int8_weight_only_transform( |
1396 | 1396 | torch.nn.Parameter(quantized_tensor, requires_grad=False), |
1397 | 1397 | ) |
1398 | 1398 | module.extra_repr = types.MethodType( |
1399 | | - partial(_module_extra_repr, original_extra_repr=module.extra_repr), module |
| 1399 | + partial( |
| 1400 | + _module_extra_repr, |
| 1401 | + original_extra_repr=module.extra_repr, |
| 1402 | + parameter_name=parameter_name, |
| 1403 | + ), |
| 1404 | + module, |
1400 | 1405 | ) |
1401 | 1406 | return module |
1402 | 1407 |
|
@@ -1692,7 +1697,12 @@ def _float8_weight_only_transform( |
1692 | 1697 | torch.nn.Parameter(quantized_tensor, requires_grad=False), |
1693 | 1698 | ) |
1694 | 1699 | module.extra_repr = types.MethodType( |
1695 | | - partial(_module_extra_repr, original_extra_repr=module.extra_repr), module |
| 1700 | + partial( |
| 1701 | + _module_extra_repr, |
| 1702 | + original_extra_repr=module.extra_repr, |
| 1703 | + parameter_name=parameter_name, |
| 1704 | + ), |
| 1705 | + module, |
1696 | 1706 | ) |
1697 | 1707 | return module |
1698 | 1708 |
|
@@ -1941,7 +1951,12 @@ def _float8_dynamic_activation_float8_weight_transform( |
1941 | 1951 | torch.nn.Parameter(quantized_tensor, requires_grad=False), |
1942 | 1952 | ) |
1943 | 1953 | module.extra_repr = types.MethodType( |
1944 | | - partial(_module_extra_repr, original_extra_repr=module.extra_repr), module |
| 1954 | + partial( |
| 1955 | + _module_extra_repr, |
| 1956 | + original_extra_repr=module.extra_repr, |
| 1957 | + parameter_name=parameter_name, |
| 1958 | + ), |
| 1959 | + module, |
1945 | 1960 | ) |
1946 | 1961 | return module |
1947 | 1962 |
|
|
0 commit comments