58 changes: 58 additions & 0 deletions test/quantization/test_quant_api.py
@@ -853,6 +853,64 @@ def test_config_deprecation(self):
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_sm_at_least_90(), "Checkpoints are produced in SM90+")
class TestFqnToConfig(TestCase):
    def test_fqn_to_config_repr_custom(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_parameter(
                    "x", torch.nn.Parameter(torch.randn(128, 128, dtype=torch.bfloat16))
                )
                self.register_parameter(
                    "y", torch.nn.Parameter(torch.randn(128, 128, dtype=torch.bfloat16))
                )

        custom_module = TestModule().cuda().eval()
        custom_module_config = FqnToConfig(
            {
                "x": Float8DynamicActivationFloat8WeightConfig(
                    granularity=PerTensor(),
                ),
            }
        )
        quantize_(
            custom_module,
            custom_module_config,
            filter_fn=None,
        )
        expected_str = (
            "TestModule(x=Float8Tensor(self.act_quant_kwargs=QuantizeTensorToFloat8Kwargs("
            "float8_dtype=torch.float8_e4m3fn, granularity=PerTensor(), mm_config=None, "
            "hp_value_lb=None, hp_value_ub=None, kernel_preference=<KernelPreference.AUTO: 'auto'>), "
            "self.block_size=[128, 128], self.mm_config=Float8MMConfig(emulate=False, use_fast_accum=True, "
            "pad_inner_dim=False), self.scale.shape=torch.Size([1, 1]), self.kernel_preference=<KernelPreference.AUTO: 'auto'>))"
        )
        assert str(custom_module) == expected_str
jerryzh168 (Contributor) commented on Nov 11, 2025:
might be a bit fragile? since small changes will break it, maybe use FileCheck()?
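
A minimal sketch of that suggestion, reusing the quantized `custom_module` built in the test above (the exact substrings checked here are illustrative): `torch.testing.FileCheck` asserts that key fragments appear in order rather than pinning the full repr, so cosmetic changes to `Float8Tensor.__repr__` would not break the test.

from torch.testing import FileCheck

# Check that the key pieces of the quantized repr appear, in order,
# without matching the entire string verbatim.
(
    FileCheck()
    .check("TestModule(")
    .check("x=Float8Tensor(")
    .check("granularity=PerTensor()")
    .check("self.scale.shape=torch.Size([1, 1])")
    .run(str(custom_module))
)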


    def test_fqn_to_config_repr_linear(self):
        linear_model = ToyLinearModel().to(torch.bfloat16).cuda().eval()
        linear_quant_config = FqnToConfig(
            {
                "linear1.weight": Float8DynamicActivationFloat8WeightConfig(
                    granularity=PerTensor(),
                ),
            }
        )
        quantize_(
            linear_model,
            linear_quant_config,
            filter_fn=None,
        )
        expected_str = (
            "Linear(in_features=64, out_features=32, bias=False, "
            "weight=Float8Tensor(self.act_quant_kwargs=QuantizeTensorToFloat8Kwargs("
            "float8_dtype=torch.float8_e4m3fn, granularity=PerTensor(), mm_config=None, "
            "hp_value_lb=None, hp_value_ub=None, kernel_preference=<KernelPreference.AUTO: 'auto'>), "
            "self.block_size=[32, 64], self.mm_config=Float8MMConfig(emulate=False, use_fast_accum=True, "
            "pad_inner_dim=False), self.scale.shape=torch.Size([1, 1]), self.kernel_preference=<KernelPreference.AUTO: 'auto'>))"
        )

        assert str(linear_model.linear1) == expected_str
Contributor:

same here


    def test_quantize_param_fqn_exact(self):
        from transformers import AutoConfig
        from transformers.models.llama4.modeling_llama4 import Llama4TextMoe
53 changes: 46 additions & 7 deletions torchao/quantization/quant_api.py
@@ -21,6 +21,7 @@
import warnings
from collections import OrderedDict
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Callable, List, Optional, Tuple, Union
from typing import OrderedDict as OrderedDictType

@@ -414,6 +415,19 @@ def _embedding_extra_repr(self):
    return f"num_embeddings={self.weight.shape[0]}, embedding_dim={self.weight.shape[1]}, weight={_quantization_type(self.weight)}"


def _module_extra_repr(self, original_extra_repr, parameter_name):
    module_torchao_extra_repr = []

    original_extra_repr_str = original_extra_repr()
    if len(original_extra_repr_str) > 0:
        module_torchao_extra_repr.append(original_extra_repr_str)

    module_torchao_extra_repr.append(
        f"{parameter_name}={_quantization_type(getattr(self, parameter_name))}"
    )
    return ", ".join(module_torchao_extra_repr)


def _get_linear_subclass_inserter(
    constructor, *, allow_requires_grad=False, propagate_bias=False, **kwargs
):
@@ -1373,11 +1387,22 @@ def _int8_weight_only_transform(
"applying int8 weight only quant requires module to have {parameter_name} attribute"
+ " but {module} does not have one"
)
new_weight = _int8_weight_only_quantize_tensor(
quantized_tensor = _int8_weight_only_quantize_tensor(
getattr(module, parameter_name), config
)
setattr(module, parameter_name, torch.nn.Parameter(new_weight, requires_grad=False))
module.extra_repr = types.MethodType(_linear_extra_repr, module)
setattr(
module,
parameter_name,
torch.nn.Parameter(quantized_tensor, requires_grad=False),
)
module.extra_repr = types.MethodType(
partial(
_module_extra_repr,
original_extra_repr=module.extra_repr,
parameter_name=parameter_name,
),
module,
)
return module


@@ -1662,16 +1687,23 @@ def _float8_weight_only_transform(
    if isinstance(module, Float8Linear):
        module = _unwrap_float8_linear(module)

    new_weight = _float8_weight_only_quant_tensor(
    quantized_tensor = _float8_weight_only_quant_tensor(
        getattr(module, parameter_name), config
    )

    setattr(
        module,
        parameter_name,
        torch.nn.Parameter(new_weight, requires_grad=False),
        torch.nn.Parameter(quantized_tensor, requires_grad=False),
    )
    module.extra_repr = types.MethodType(
        partial(
            _module_extra_repr,
            original_extra_repr=module.extra_repr,
            parameter_name=parameter_name,
        ),
        module,
    )
    module.extra_repr = types.MethodType(_linear_extra_repr, module)
    return module


@@ -1918,7 +1950,14 @@ def _float8_dynamic_activation_float8_weight_transform(
        parameter_name,
        torch.nn.Parameter(quantized_tensor, requires_grad=False),
    )
    module.extra_repr = types.MethodType(_linear_extra_repr, module)
    module.extra_repr = types.MethodType(
        partial(
            _module_extra_repr,
            original_extra_repr=module.extra_repr,
            parameter_name=parameter_name,
        ),
        module,
    )
    return module


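For readers unfamiliar with the pattern the new `_module_extra_repr` helper relies on, here is a self-contained sketch (not part of this PR) of overriding a module's `extra_repr` with `functools.partial` plus `types.MethodType`; `type(...).__name__` stands in for torchao's `_quantization_type` so the snippet runs without quantizing anything.

import types
from functools import partial

import torch


def _module_extra_repr(self, original_extra_repr, parameter_name):
    # Combine the module's original extra_repr with a short summary of the
    # named parameter, mirroring the helper added in quant_api.py above.
    parts = []
    original_extra_repr_str = original_extra_repr()
    if len(original_extra_repr_str) > 0:
        parts.append(original_extra_repr_str)
    parts.append(f"{parameter_name}={type(getattr(self, parameter_name)).__name__}")
    return ", ".join(parts)


linear = torch.nn.Linear(64, 32, bias=False)
# Bind the helper as this instance's extra_repr; partial captures the
# original bound extra_repr and the parameter name before the override.
linear.extra_repr = types.MethodType(
    partial(
        _module_extra_repr,
        original_extra_repr=linear.extra_repr,
        parameter_name="weight",
    ),
    linear,
)
# Prints: Linear(in_features=64, out_features=32, bias=False, weight=Parameter)
print(linear)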