Conversation

@jerryzh168
Contributor

@jerryzh168 jerryzh168 commented May 1, 2025

Summary:
Currently, when we quantize the input embedding for some models, the output embedding (lm_head) gets quantized the same way because the two weights are tied, and this may not be what we want. To break the tie, this PR adds an option that lets people

  1. load the unquantized weights
  2. tie the weights
  3. quantize

so that the tie is broken (see the sketch after this list).
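
To make the ordering concrete, here is a minimal, self-contained sketch in plain PyTorch (an illustration only, not the actual transformers/torchao code path): once the weights are tied, quantization replaces the embedding's weight with a new tensor, so lm_head keeps the original floating-point parameter and the tie is gone.

```
import torch
import torch.nn as nn

embed = nn.Embedding(10, 4)
lm_head = nn.Linear(4, 10, bias=False)

# step 2: tie the weights -- both modules now share a single Parameter
lm_head.weight = embed.weight
assert lm_head.weight is embed.weight

# step 3: "quantize" the embedding by swapping in a new weight tensor
# (a stand-in for torchao replacing the weight with a quantized subclass)
embed.weight = nn.Parameter(embed.weight.detach().round())

# the tie is broken: lm_head still holds the original float weight
assert lm_head.weight is not embed.weight
```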

Test Plan:

```
from transformers import (
  AutoModelForCausalLM,
  AutoProcessor,
  AutoTokenizer,
  TorchAoConfig,
)
from torchao.quantization.quant_api import (
    IntxWeightOnlyConfig,
    Int8DynamicActivationIntxWeightConfig,
    AOPerModuleConfig
)
from torchao.quantization.granularity import PerGroup, PerAxis
import torch

model_id = "microsoft/Phi-4-mini-instruct"

# int8 weight-only, per-axis (per-row) quantization for the embedding table
embedding_config = IntxWeightOnlyConfig(
    weight_dtype=torch.int8,
    granularity=PerAxis(0),
)
# int8 dynamic activations + int4 weights with group size 32 for the linear layers
linear_config = Int8DynamicActivationIntxWeightConfig(
    weight_dtype=torch.int4,
    weight_granularity=PerGroup(32),
    weight_scale_dtype=torch.bfloat16,
)
# "_default" covers the remaining (linear) modules; "model.embed_tokens" gets the embedding config
quant_config = AOPerModuleConfig({"_default": linear_config, "model.embed_tokens": embedding_config})
# include_embedding=True quantizes model.embed_tokens; untie_embedding_weights=True
# ties the weights before quantizing, so lm_head keeps its unquantized weight
quantization_config = TorchAoConfig(quant_type=quant_config, include_embedding=True, untie_embedding_weights=True)
quantized_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32, device_map="auto", quantization_config=quantization_config)
tokenizer = AutoTokenizer.from_pretrained(model_id)

print(quantized_model)
# embed_tokens should now be quantized, lm_head should keep its float weight,
# and find_tied_parameters should return [] since the tie is broken
print("embed_tokens.weight:", quantized_model.model.embed_tokens.weight)
print("lm head weight:", quantized_model.lm_head.weight)
from transformers.modeling_utils import find_tied_parameters
print(find_tied_parameters(quantized_model))

# Manual Testing
prompt = "Hey, are you conscious? Can you talk to me?"
messages = [
    {
        "role": "system",
        "content": "",
    },
    {"role": "user", "content": prompt},
]
templated_prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
print("Prompt:", prompt)
print("Templated prompt:", templated_prompt)
inputs = tokenizer(
    templated_prompt,
    return_tensors="pt",
).to("cuda")
generated_ids = quantized_model.generate(**inputs, max_new_tokens=128)
output_text = tokenizer.batch_decode(
    generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print("Response:", output_text[0][len(prompt):])
```

Output:

```
Phi3ForCausalLM(
  (model): Phi3Model(
    (embed_tokens): Embedding(200064, 3072, padding_idx=199999)
    (layers): ModuleList(
      (0-31): 32 x Phi3DecoderLayer(
        (self_attn): Phi3Attention(
          (o_proj): Linear(in_features=3072, out_features=3072, weight=LinearActivationQuantizedTensor(activation=<function _int8_asymm_per_token_quant at 0x7f3973d8f250>, weight=AffineQuantizedTensor(shape=torch.Size([3072, 3072]), block_size=(1, 32), device=cuda:0, _layout=QDQLayout(), tensor_impl_dtype=torch.int8, quant_min=-8, quant_max=7)))
          (qkv_proj): Linear(in_features=3072, out_features=5120, weight=LinearActivationQuantizedTensor(activation=<function _int8_asymm_per_token_quant at 0x7f3973d8f250>, weight=AffineQuantizedTensor(shape=torch.Size([5120, 3072]), block_size=(1, 32), device=cuda:0, _layout=QDQLayout(), tensor_impl_dtype=torch.int8, quant_min=-8, quant_max=7)))
        )
        (mlp): Phi3MLP(
          (gate_up_proj): Linear(in_features=3072, out_features=16384, weight=LinearActivationQuantizedTensor(activation=<function _int8_asymm_per_token_quant at 0x7f3973d8f250>, weight=AffineQuantizedTensor(shape=torch.Size([16384, 3072]), block_size=(1, 32), device=cuda:0, _layout=QDQLayout(), tensor_impl_dtype=torch.int8, quant_min=-8, quant_max=7)))
          (down_proj): Linear(in_features=8192, out_features=3072, weight=LinearActivationQuantizedTensor(activation=<function _int8_asymm_per_token_quant at 0x7f3973d8f250>, weight=AffineQuantizedTensor(shape=torch.Size([3072, 8192]), block_size=(1, 32), device=cuda:0, _layout=QDQLayout(), tensor_impl_dtype=torch.int8, quant_min=-8, quant_max=7)))
          (activation_fn): SiLU()
        )
        (input_layernorm): Phi3RMSNorm((3072,), eps=1e-05)
        (post_attention_layernorm): Phi3RMSNorm((3072,), eps=1e-05)
        (resid_attn_dropout): Dropout(p=0.0, inplace=False)
        (resid_mlp_dropout): Dropout(p=0.0, inplace=False)
      )
    )
    (norm): Phi3RMSNorm((3072,), eps=1e-05)
    (rotary_emb): Phi3RotaryEmbedding()
  )
  (lm_head): Linear(in_features=3072, out_features=200064, bias=False)
)
embed_tokens.weight: AffineQuantizedTensor(tensor_impl=QDQTensorImpl(data=tensor([[-20,   4,  13,  ...,   8,  -5,  -3],
        [ -2,   1,  13,  ...,   0, -18,  15],
        [  1,   2,  11,  ...,  15,   0,  18],
        ...,
        [  0,  -2,   7,  ...,   4,  10,  12],
        [  0,  -2,   7,  ...,   4,  10,  12],
        [  0,  -2,   7,  ...,   4,  10,  12]], device='cuda:0',
       dtype=torch.int8)... , scale=tensor([0.0083, 0.0099, 0.0115,  ..., 0.0009, 0.0009, 0.0009], device='cuda:0')... , zero_point=tensor([0, 0, 0,  ..., 0, 0, 0], device='cuda:0', dtype=torch.int8)... , _layout=QDQLayout()), block_size=(1, 3072), shape=torch.Size([200064, 3072]), device=cuda:0, dtype=torch.float32, requires_grad=False)
lm head weight: Parameter containing:
tensor([[-0.1689,  0.0317,  0.1060,  ...,  0.0635, -0.0378, -0.0260],
        [-0.0233,  0.0072,  0.1299,  ...,  0.0013, -0.1748,  0.1465],
        [ 0.0159,  0.0206,  0.1260,  ...,  0.1748, -0.0027,  0.2041],
        ...,
        [ 0.0002, -0.0020,  0.0062,  ...,  0.0038,  0.0095,  0.0113],
        [ 0.0002, -0.0020,  0.0062,  ...,  0.0038,  0.0095,  0.0113],
        [ 0.0002, -0.0020,  0.0062,  ...,  0.0038,  0.0095,  0.0113]],
       device='cuda:0')
[]
Prompt: Hey, are you conscious? Can you talk to me?
Templated prompt: <|system|><|end|><|user|>Hey, are you conscious? Can you talk to me?<|end|><|assistant|>
Response: Hello! As an AI, I don't have consciousness in the way humans do, but I'm fully operational and here to assist you. How can I help you today?
```
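
For reference, the printed output above can also be turned into assertions. A minimal sketch run right after the test-plan script, assuming AffineQuantizedTensor is importable from torchao.dtypes (the exact import path may vary across torchao versions):

```
import torch
from torchao.dtypes import AffineQuantizedTensor
from transformers.modeling_utils import find_tied_parameters

# embed_tokens now holds a quantized tensor subclass ...
assert isinstance(quantized_model.model.embed_tokens.weight, AffineQuantizedTensor)
# ... while lm_head keeps its original float32 weight ...
assert not isinstance(quantized_model.lm_head.weight, AffineQuantizedTensor)
assert quantized_model.lm_head.weight.dtype == torch.float32
# ... and no tied parameters remain
assert find_tied_parameters(quantized_model) == []
```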

Reviewers:

Subscribers:

Tasks:

Tags:

@github-actions github-actions bot marked this pull request as draft May 1, 2025 00:21
@github-actions
Contributor

github-actions bot commented May 1, 2025

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the Ready for review button (at the bottom of the PR page). This will assign reviewers and trigger CI.

@jerryzh168 jerryzh168 marked this pull request as ready for review May 1, 2025 02:39
@github-actions github-actions bot requested review from MekkCyber and SunMarc May 1, 2025 02:39
Member

@SunMarc SunMarc left a comment

Thanks for adding this, happy to merge it if it unblocks you for now, but let's clean this up later. Added some comments.

self.modules_to_not_convert = modules_to_not_convert
self.quant_type_kwargs = kwargs.get("quant_type_kwargs", kwargs)
self.include_embedding = include_embedding
self.untie_embedding_weights = untie_embedding_weights
Member

Make sure not to save this key when serializing the model config, as it is redundant with tie_word_embeddings.

Contributor Author

It's a bit different, I think. Together with tie_word_embeddings == True, this flag is specifically meant to achieve the order

  1. load unquantized weights
  2. tie weights
  3. quantize (this will untie the weights)

while tie_word_embeddings == True without the flag means:

  1. load unquantized weights
  2. quantize
  3. tie weights (tied quantized weights)

and tie_word_embeddings == False means:

  1. load unquantized weights
  2. quantize
  3. don't tie weights

(see the sketch below for how these cases map to the config flags)
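
As an illustration only (not a definitive description of the internals), the three cases above map to configs roughly as follows, reusing the quant_config built in the test plan; tie_word_embeddings lives in the model's own config, not in TorchAoConfig.

```
from transformers import TorchAoConfig

# (1) this PR: load -> tie -> quantize; quantization replaces embed_tokens.weight,
#     so the tie to lm_head is broken and lm_head stays in floating point
untied = TorchAoConfig(quant_type=quant_config, include_embedding=True,
                       untie_embedding_weights=True)

# (2) checkpoint with tie_word_embeddings=True and the flag left at its default:
#     load -> quantize -> tie, so lm_head shares the quantized embedding weight
tied_quantized = TorchAoConfig(quant_type=quant_config, include_embedding=True)

# (3) checkpoint with tie_word_embeddings=False: the weights are never tied,
#     so the new flag makes no difference
```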

# 1. load the weight to model
# 2. run tie_weights to populate the weights
# 3. quantize
input_embed = model.get_input_embeddings()
Member

Another solution that would be a tad cleaner would be to not quantize the embeddings here and do it in _process_model_after_weight_loading instead.

Contributor Author

Yeah, we can explore this path. The main constraint here is that we need to save a valid quantization_config together with the quantized model, so that people can load the model + quantization_config in other frameworks like lm-eval and vLLM.
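
For context, a minimal sketch of that round-trip, building on the test-plan script above; the output directory name is hypothetical, and the assumption that torchao tensor-subclass weights need safe_serialization=False may depend on the transformers/torchao versions:

```
# Sketch only: persist the quantized model so the quantization_config travels with
# it in config.json and can be picked up by other frameworks (lm-eval, vLLM, ...)
save_dir = "phi4-mini-instruct-torchao"  # hypothetical path
quantized_model.save_pretrained(save_dir, safe_serialization=False)
tokenizer.save_pretrained(save_dir)

# reloading picks the quantization settings back up from the saved config
reloaded = AutoModelForCausalLM.from_pretrained(save_dir, torch_dtype=torch.float32, device_map="auto")
```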

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@jerryzh168
Contributor Author

@SunMarc thanks Marc, yeah it would be good to merge this first to unblock us; we can improve it a bit later.

Contributor

@MekkCyber MekkCyber left a comment

Sounds good! Just waiting for the CI so we can merge.

@jerryzh168
Contributor Author

@SunMarc @MekkCyber can you merge this?

@MekkCyber
Contributor

MekkCyber commented May 1, 2025

Can't merge it now because the status of build_pr_documentation is not reported yet. If the issue persists, I'll ask a core maintainer to merge it first thing tomorrow CET.

Collaborator

@ArthurZucker ArthurZucker left a comment

LGTM, might be good for other quantization methods!

@ArthurZucker ArthurZucker merged commit fa3c3f9 into huggingface:main May 2, 2025
20 checks passed
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request May 14, 2025