
Commit 43a1f46

Add __str__ to FqnToConfig to make printing more readable (#3323)
* Adds __str__ to FqnToConfig to make printing more readable

Summary: as titled; adds a `__str__` method to `FqnToConfig` so that printing it is more legible. For a config such as:

```python
config = FqnToConfig({
    "model.layers.fig.1.1": Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow(),
    ),
    "model.layers.fig.1.3": Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow(),
    ),
    "model.layers.fig.8.3": Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow(),
    ),
})
```

the output will be:

```
FqnToConfig({
    'model.layers.fig.1.1':
        Float8DynamicActivationFloat8WeightConfig(activation_dtype=torch.float8_e4m3fn, weight_dtype=torch.float8_e4m3fn, granularity=[PerRow(dim=-1), PerRow(dim=-1)], mm_config=Float8MMConfig(emulate=False, use_fast_accum=True, pad_inner_dim=False), activation_value_lb=None, activation_value_ub=None, kernel_preference=<KernelPreference.AUTO: 'auto'>, set_inductor_config=True, version=2),
    'model.layers.fig.1.3':
        Float8DynamicActivationFloat8WeightConfig(activation_dtype=torch.float8_e4m3fn, weight_dtype=torch.float8_e4m3fn, granularity=[PerRow(dim=-1), PerRow(dim=-1)], mm_config=Float8MMConfig(emulate=False, use_fast_accum=True, pad_inner_dim=False), activation_value_lb=None, activation_value_ub=None, kernel_preference=<KernelPreference.AUTO: 'auto'>, set_inductor_config=True, version=2),
    'model.layers.fig.8.3':
        Float8DynamicActivationFloat8WeightConfig(activation_dtype=torch.float8_e4m3fn, weight_dtype=torch.float8_e4m3fn, granularity=[PerRow(dim=-1), PerRow(dim=-1)], mm_config=Float8MMConfig(emulate=False, use_fast_accum=True, pad_inner_dim=False), activation_value_lb=None, activation_value_ub=None, kernel_preference=<KernelPreference.AUTO: 'auto'>, set_inductor_config=True, version=2),
})
```

Also adds a test ensuring that you cannot specify both `fqn_to_config` and `module_fqn_to_config` unless they are equal.

Test Plan:

```
pytest test/quantization/test_quant_api.py -k test_fqn_config_module_config_and_fqn_config_both_specified
```

* fix ruff check
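For quick reference, here is a minimal usage sketch of the new printing behavior. This is a sketch only: it assumes `FqnToConfig`, `Float8DynamicActivationFloat8WeightConfig`, and `PerRow` are importable from `torchao.quantization` (as used in the commit message above), and the fqn key is illustrative, not taken from any real model.

```python
# Minimal sketch; assumes these names are importable from torchao.quantization,
# matching the identifiers used in the commit message above.
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    FqnToConfig,
    PerRow,
)

config = FqnToConfig(
    {
        # Illustrative fqn key; any module/parameter fqn in your model works here.
        "model.layers.0.self_attn.q_proj": Float8DynamicActivationFloat8WeightConfig(
            granularity=PerRow(),
        ),
    }
)

# print() calls the new __str__, which emits one "'fqn':\n    <config>," entry
# per key between "FqnToConfig({" and "})".
print(config)
```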
1 parent 1a9b6f4 commit 43a1f46

File tree: 2 files changed, +28 -0 lines changed
test/quantization/test_quant_api.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -1178,6 +1178,13 @@ def __init__(self):
         assert isinstance(m.nested.linear.weight, AffineQuantizedTensor)
         assert isinstance(m.linear1.weight, AffineQuantizedTensor)
 
+    def test_fqn_config_module_config_and_fqn_config_both_specified(self):
+        with self.assertRaises(ValueError):
+            FqnToConfig(
+                fqn_to_config={"test": Float8WeightOnlyConfig()},
+                module_fqn_to_config={"test2": Float8WeightOnlyConfig()},
+            )
+
 
 if __name__ == "__main__":
     unittest.main()
```

torchao/quantization/quant_api.py

Lines changed: 21 additions & 0 deletions
```diff
@@ -2466,6 +2466,15 @@ class FqnToConfig(AOBaseConfig):
     def __post_init__(self):
         torch._C._log_api_usage_once("torchao.quantization.FqnToConfig")
 
+        if (
+            len(self.fqn_to_config) > 0
+            and len(self.module_fqn_to_config) > 0
+            and self.fqn_to_config != self.module_fqn_to_config
+        ):
+            raise ValueError(
+                "`fqn_to_config` and `module_fqn_to_config` are both specified and are not equal!"
+            )
+
         # This code handles BC compatibility with `ModuleFqnToConfig`. It ensures that `self.module_fqn_to_config` and `self.fqn_to_config` share the same object.
         if len(self.module_fqn_to_config) > 0 and len(self.fqn_to_config) == 0:
             self.fqn_to_config = self.module_fqn_to_config
@@ -2479,6 +2488,18 @@ def __post_init__(self):
                 "Config Deprecation: _default is deprecated and will no longer be supported in a future release. Please see https://github.com/pytorch/ao/issues/3229 for more details."
             )
 
+    def __str__(self):
+        return "\n".join(
+            [
+                "FqnToConfig({",
+                *(
+                    f"    '{key}':\n        {value},"
+                    for key, value in self.fqn_to_config.items()
+                ),
+                "})",
+            ]
+        )
+
 
 # maintain BC
 ModuleFqnToConfig = FqnToConfig
```
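To make the new validation concrete, a hedged sketch of the resulting constructor behavior follows. The fqn keys are illustrative, and it assumes `FqnToConfig` and `Float8WeightOnlyConfig` are importable from `torchao.quantization`, as in the test above.

```python
from torchao.quantization import Float8WeightOnlyConfig, FqnToConfig

# Equal mappings in both arguments are still accepted (the BC escape hatch
# allowed by the equality check in __post_init__).
cfg = {"model.linear": Float8WeightOnlyConfig()}  # illustrative fqn key
FqnToConfig(fqn_to_config=cfg, module_fqn_to_config=cfg)  # ok

# Unequal, non-empty mappings now raise at construction time.
try:
    FqnToConfig(
        fqn_to_config={"a": Float8WeightOnlyConfig()},
        module_fqn_to_config={"b": Float8WeightOnlyConfig()},
    )
except ValueError as e:
    print(e)
    # `fqn_to_config` and `module_fqn_to_config` are both specified and are not equal!
```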
