
Conversation

@jcaip (Contributor) commented Nov 11, 2025

Summary:

This fixes extra_repr for generic modules by introducing a new helper function, _module_extra_repr, which appends TorchAOTensor info for every parameter in a module that is an instance of TorchAOBaseTensor to the module's original extra_repr.

The configs supporting parameter quantization have been updated to use _module_extra_repr.

Also renamed new_weight -> quantized_tensor to be more consistent.
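
For illustration, a minimal sketch of what such a helper could look like (the signature, the import path for TorchAOBaseTensor, and the exact formatting are assumptions, not the PR's actual implementation):

import torch.nn as nn

from torchao.utils import TorchAOBaseTensor  # assumed import path


def _module_extra_repr(original_extra_repr: str, module: nn.Module) -> str:
    # Hypothetical sketch: append the repr of every direct parameter that is a
    # TorchAOBaseTensor to the module's original extra_repr string.
    parts = [original_extra_repr] if original_extra_repr else []
    for name, param in module.named_parameters(recurse=False):
        if isinstance(param, TorchAOBaseTensor):
            parts.append(f"{name}={param}")
    return ", ".join(parts)

Presumably this helper is bound as the module's extra_repr, so that print(model) shows the quantized parameter info, as in the output below.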

For example, we will see the following output

    (language_model): Qwen3VLMoeTextModel(
      (embed_tokens): Embedding(151936, 2048)
      (layers): ModuleList(
        (0-47): 48 x Qwen3VLMoeTextDecoderLayer(
          (self_attn): Qwen3VLMoeTextAttention(
            (q_proj): Linear(in_features=2048, out_features=4096, bias=False, weight=Float8Tensor(self.act_quant_kwargs=QuantizeTensorToFloat8Kwargs(float8_dtype=torch.float8_e4m3fn, granularity=PerRow(dim=-1), mm_config=None, hp_value_lb=None, hp_value_ub=None, kernel_preference=<KernelPreference.AUTO: 'auto'>), self.block_size=[1, 2048], self.mm_config=Float8MMConfig(emulate=False, use_fast_accum=True, pad_inner_dim=False), self.scale.shape=torch.Size([4096, 1]), self.kernel_preference=<KernelPreference.AUTO: 'auto'>))
            (k_proj): Linear(in_features=2048, out_features=512, bias=False, weight=Float8Tensor(self.act_quant_kwargs=QuantizeTensorToFloat8Kwargs(float8_dtype=torch.float8_e4m3fn, granularity=PerRow(dim=-1), mm_config=None, hp_value_lb=None, hp_value_ub=None, kernel_preference=<KernelPreference.AUTO: 'auto'>), self.block_size=[1, 2048], self.mm_config=Float8MMConfig(emulate=False, use_fast_accum=True, pad_inner_dim=False), self.scale.shape=torch.Size([512, 1]), self.kernel_preference=<KernelPreference.AUTO: 'auto'>))
            (v_proj): Linear(in_features=2048, out_features=512, bias=False, weight=Float8Tensor(self.act_quant_kwargs=QuantizeTensorToFloat8Kwargs(float8_dtype=torch.float8_e4m3fn, granularity=PerRow(dim=-1), mm_config=None, hp_value_lb=None, hp_value_ub=None, kernel_preference=<KernelPreference.AUTO: 'auto'>), self.block_size=[1, 2048], self.mm_config=Float8MMConfig(emulate=False, use_fast_accum=True, pad_inner_dim=False), self.scale.shape=torch.Size([512, 1]), self.kernel_preference=<KernelPreference.AUTO: 'auto'>))
            (o_proj): Linear(in_features=4096, out_features=2048, bias=False)
            (q_norm): Qwen3VLMoeTextRMSNorm((128,), eps=1e-06)
            (k_norm): Qwen3VLMoeTextRMSNorm((128,), eps=1e-06)
          )
          (mlp): Qwen3VLMoeTextSparseMoeBlock(
            (gate): Linear(in_features=2048, out_features=128, bias=False)
            (experts): Qwen3VLMoeTextExperts(
              gate_up_proj=Float8Tensor(self.act_quant_kwargs=QuantizeTensorToFloat8Kwargs(float8_dtype=torch.float8_e4m3fn, granularity=PerRow(dim=-1), mm_config=None, hp_value_lb=None, hp_value_ub=None, kernel_preference=<KernelPreference.AUTO: 'auto'>), self.block_size=[1, 2048, 1], self.mm_config=Float8MMConfig(emulate=False, use_fast_accum=True, pad_inner_dim=False), self.scale.shape=torch.Size([128, 1, 1536]), self.kernel_preference=<KernelPreference.AUTO: 'auto'>), down_proj=Float8Tensor(self.act_quant_kwargs=QuantizeTensorToFloat8Kwargs(float8_dtype=torch.float8_e4m3fn, granularity=PerRow(dim=-1), mm_config=None, hp_value_lb=None, hp_value_ub=None, kernel_preference=<KernelPreference.AUTO: 'auto'>), self.block_size=[1, 768, 1], self.mm_config=Float8MMConfig(emulate=False, use_fast_accum=True, pad_inner_dim=False), self.scale.shape=torch.Size([128, 1, 2048]), self.kernel_preference=<KernelPreference.AUTO: 'auto'>)
              (act_fn): SiLUActivation()
            )
          )
          (input_layernorm): Qwen3VLMoeTextRMSNorm((2048,), eps=1e-06)
          (post_attention_layernorm): Qwen3VLMoeTextRMSNorm((2048,), eps=1e-06)
        )
      )
      (norm): Qwen3VLMoeTextRMSNorm((2048,), eps=1e-06)
      (rotary_emb): Qwen3VLMoeTextRotaryEmbedding()
    )
  )
  (lm_head): Linear(in_features=2048, out_features=151936, bias=False)

when we run the following code:

import logging

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig, Qwen3VLMoeForConditionalGeneration, AutoProcessor

from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    FqnToConfig,
    quantize_
)

# Configure logging to see warnings and debug information
logging.basicConfig(
    level=logging.INFO, format="%(name)s - %(levelname)s - %(message)s"
)

# Enable specific loggers that might contain the serialization warnings
logging.getLogger("transformers").setLevel(logging.INFO)
logging.getLogger("torchao").setLevel(logging.INFO)
logging.getLogger("safetensors").setLevel(logging.INFO)
logging.getLogger("huggingface_hub").setLevel(logging.INFO)

config = Float8DynamicActivationFloat8WeightConfig(
    granularity=PerRow(),
)
expert_config = Float8DynamicActivationFloat8WeightConfig(
    granularity=[PerRow(), PerRow(1)],
)


# only quantize language model
quant_config = FqnToConfig({
    r"re:model.language_model.*.gate_up_proj": expert_config,
    r"re:model.language_model.*.down_proj": expert_config,
    r"re:model.language_model.*.q_proj": config,
    r"re:model.language_model.*.k_proj": config,
    r"re:model.language_model.*.v_proj": config,
})
quantization_config = TorchAoConfig(quant_type=quant_config)
model = Qwen3VLMoeForConditionalGeneration.from_pretrained(
    "Qwen/Qwen3-VL-30B-A3B-Instruct",
    dtype="auto",
    device_map="auto",
    quantization_config=quantization_config
)

print(model)

Test Plan:

pytest test/quantization/test_quant_api.py -k test_fqn_to_config_repr


@meta-cla bot added the CLA Signed label Nov 11, 2025
@pytorch-bot commented Nov 11, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3328

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (4 Unrelated Failures)

As of commit 50db29f with merge base b4ec4cb:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@jcaip added the topic: bug fix label Nov 11, 2025
@vkuzo (Contributor) left a comment

is there a test we can add print(module) to, to prevent regressions?

@jcaip changed the title: Fix nn.Linear module repr for param quantization → Add generic TorchAOTensor extra_repr for nn.Modules Nov 11, 2025
@jcaip requested a review from jerryzh168 November 11, 2025 21:18
"self.block_size=[128, 128], self.mm_config=Float8MMConfig(emulate=False, use_fast_accum=True, "
"pad_inner_dim=False), self.scale.shape=torch.Size([1, 1]), self.kernel_preference=<KernelPreference.AUTO: 'auto'>))"
)
assert str(custom_module) == expected_str
@jerryzh168 (Contributor) commented Nov 11, 2025

might be a bit fragile? since small changes will break it, maybe use FileCheck()?
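
For reference, a sketch of what a FileCheck-based check could look like for the custom_module repr in the excerpt above (the chosen substrings are illustrative assumptions, not the actual test):

from torch.testing import FileCheck

# Match a few stable substrings of the repr instead of the full string,
# so that small formatting changes do not break the test.
repr_check = FileCheck().check("Float8Tensor").check("block_size")
repr_check.run(str(custom_module))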

"pad_inner_dim=False), self.scale.shape=torch.Size([1, 1]), self.kernel_preference=<KernelPreference.AUTO: 'auto'>))"
)

assert str(linear_model.linear1) == expected_str

same here

@jerryzh168 (Contributor) left a comment

looks good, had some comments inline
