Break weight tying when quantizing input embedding #37905
Conversation
Summary:
Currently, when we quantize the input embedding (model.embed_tokens) for some models, the output embedding (lm_head) gets quantized the same way because the two weights are tied, and this is not always what we want. To break the tie, this PR adds an option that changes the order of operations to
1. load the unquantized weights
2. tie the weights
3. quantize
so that quantizing the input embedding breaks the tie.
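For context, a minimal sketch (not part of this PR, and assuming the checkpoint ties its word embeddings, as the models targeted here do) of why quantizing one side affects the other: the input and output embeddings share a single weight tensor.
```
import torch
from transformers import AutoModelForCausalLM

# Illustration only: any causal LM checkpoint with tie_word_embeddings=True behaves this way.
model = AutoModelForCausalLM.from_pretrained("microsoft/Phi-4-mini-instruct", torch_dtype=torch.bfloat16)

# With tied embeddings both names resolve to the same tensor object, so an in-place
# quantization pass over embed_tokens is also visible through lm_head.
print(model.lm_head.weight is model.model.embed_tokens.weight)  # expected True when tied
```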
Test Plan:
```
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
    TorchAoConfig,
)
from torchao.quantization.quant_api import (
    IntxWeightOnlyConfig,
    Int8DynamicActivationIntxWeightConfig,
    AOPerModuleConfig,
)
from torchao.quantization.granularity import PerGroup, PerAxis
import torch

model_id = "microsoft/Phi-4-mini-instruct"

# int8 weight-only quantization for the input embedding
embedding_config = IntxWeightOnlyConfig(
    weight_dtype=torch.int8,
    granularity=PerAxis(0),
)
# int8 dynamic activation / int4 weight quantization for the linear layers
linear_config = Int8DynamicActivationIntxWeightConfig(
    weight_dtype=torch.int4,
    weight_granularity=PerGroup(32),
    weight_scale_dtype=torch.bfloat16,
)
quant_config = AOPerModuleConfig({"_default": linear_config, "model.embed_tokens": embedding_config})
quantization_config = TorchAoConfig(quant_type=quant_config, include_embedding=True, untie_embedding_weights=True)

quantized_model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.float32, device_map="auto", quantization_config=quantization_config
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

print(quantized_model)
print("embed_tokens.weight:", quantized_model.model.embed_tokens.weight)
print("lm head weight:", quantized_model.lm_head.weight)

from transformers.modeling_utils import find_tied_parameters
print(find_tied_parameters(quantized_model))
```
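For reference, a hedged check of the expected outcome (continuing from the snippet above, not part of the PR's own tests): once untie_embedding_weights=True takes effect, the two modules should no longer share a parameter.
```
# The embedding and the LM head should now hold distinct weight objects,
# and the tied-parameter scan should come back empty.
assert quantized_model.lm_head.weight is not quantized_model.model.embed_tokens.weight
print(find_tied_parameters(quantized_model))  # expected: no tied groups once the tie is broken
```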
Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the Ready for review button.
SunMarc left a comment:
Thanks for adding this! I'm happy to merge it if it unblocks you for now, but let's clean this up later. I added some comments.
    self.modules_to_not_convert = modules_to_not_convert
    self.quant_type_kwargs = kwargs.get("quant_type_kwargs", kwargs)
    self.include_embedding = include_embedding
    self.untie_embedding_weights = untie_embedding_weights
Make sure not to save this key when serializing the model config, as it is redundant with tie_word_embeddings.
It's a bit different, I think. Together with tie_word_embeddings == True, this flag was specifically added to achieve the order (sketched in code below):
- load unquantized weights
- tie weights
- quantize (this unties the weights)
while tie_word_embeddings == True without the flag means:
- load unquantized weights
- quantize
- tie weights (lm_head shares the quantized weight)
and tie_word_embeddings == False means:
- load unquantized weights
- quantize
- don't tie weights
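A hedged sketch of these three paths (load_unquantized_weights and quantize_model are illustrative stand-ins, not the actual transformers internals; model.tie_weights() is the real API):
```
def load_unquantized_weights(model):
    ...  # stand-in: load the fp checkpoint weights into the model

def quantize_model(model):
    ...  # stand-in: apply the configured torchao quantization in place

def load_path(model, tie_word_embeddings: bool, untie_embedding_weights: bool):
    load_unquantized_weights(model)
    if tie_word_embeddings and untie_embedding_weights:
        model.tie_weights()    # lm_head shares embed_tokens' unquantized weight
        quantize_model(model)  # swapping in the quantized embedding weight breaks the tie
    elif tie_word_embeddings:
        quantize_model(model)
        model.tie_weights()    # lm_head ends up sharing the already-quantized weight
    else:
        quantize_model(model)  # weights are never tied
```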
    # 1. load the weight to model
    # 2. run tie_weights to populate the weights
    # 3. quantize
    input_embed = model.get_input_embeddings()
Another solution that would be a tad cleaner would be to not quantize the embeddings here and instead do it in _process_model_after_weight_loading.
Yeah, we can explore that path. The main constraint here is that we need to save a valid quantization_config together with the quantized model, so that people can load the model + quantization_config in other frameworks like lm-eval and vLLM.
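To make that constraint concrete, a hedged round-trip sketch (continuing from the Test Plan snippet; the save directory is illustrative): the saved checkpoint should carry a quantization_config that transformers, and ideally other loaders, can consume on reload.
```
# Save the quantized model and confirm the quantization_config round-trips through
# the serialized config. torchao tensor subclasses currently need
# safe_serialization=False when saving through transformers.
save_dir = "phi4-mini-instruct-torchao"  # illustrative path
quantized_model.save_pretrained(save_dir, safe_serialization=False)
tokenizer.save_pretrained(save_dir)

reloaded = AutoModelForCausalLM.from_pretrained(save_dir, device_map="auto", torch_dtype=torch.float32)
print(reloaded.config.quantization_config)  # should describe the torchao quantization applied above
```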
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@SunMarc thanks Marc, yeah it would be good to merge this to unblock us first; we can improve it a bit later.
MekkCyber left a comment:
Sounds good! Just waiting for the CI so we can merge.
@SunMarc @MekkCyber can you merge this?
Can't merge it right now because the status of build_pr_documentation is not reported yet. I'll ask a core maintainer to merge it first thing tomorrow CET if the issue persists.
ArthurZucker left a comment:
LGTM, might be good for other quantization methods!