Break weight tying when quantizing input embedding #37905
Changes from all commits
7b3cd1d
62ee0fb
8e23021
607de01
@@ -1555,6 +1555,7 @@ class TorchAoConfig(QuantizationConfigMixin):
     modules_to_not_convert: Optional[List]
     quant_type_kwargs: Dict[str, Any]
     include_embedding: bool
+    untie_embedding_weights: bool

     """This is a config class for torchao quantization/sparsity techniques.

@@ -1569,6 +1570,9 @@ class TorchAoConfig(QuantizationConfigMixin):
         inlcude_embedding (`bool`, default to `False`):
             Whether to include embedding in quantization or not, input embedding will be removed from
             the module_not_to_convert list as well if this flag is set.
+        untie_embedding_weights (`bool`, default to `False`):
+            Whether to untie the weights when we are quantizing input embedding weights that is tied
+            to other weights.
         kwargs (`Dict[str, Any]`, *optional*):
             The keyword arguments for the chosen type of quantization, for example, int4_weight_only quantization supports two keyword arguments
             `group_size` and `inner_k_tiles` currently. More API examples and documentation of arguments can be found in

@@ -1614,13 +1618,15 @@ def __init__(
         quant_type: Union[str, "AOBaseConfig"],  # noqa: F821
         modules_to_not_convert: Optional[List] = None,
         include_embedding: bool = False,
+        untie_embedding_weights: bool = False,
         **kwargs,
     ):
         self.quant_method = QuantizationMethod.TORCHAO
         self.quant_type = quant_type
         self.modules_to_not_convert = modules_to_not_convert
         self.quant_type_kwargs = kwargs.get("quant_type_kwargs", kwargs)
         self.include_embedding = include_embedding
+        self.untie_embedding_weights = untie_embedding_weights
|
Member
Make sure to not save this key when serializing the model config, as it is redundant with tie_word_embeddings.
Contributor (Author)
it's a bit different I think: together with tie_word_embeddings == True, this flag is specifically trying to achieve the order of
while tie_word_embeddings == True without the flag means:
and tie_word_embeddings == False means:
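To make the first ordering concrete, here is a rough sketch (not the PR's actual implementation; `quantize_embedding_` is a hypothetical stand-in for the torchao quantization step applied to the embedding module) of untying before quantizing the input embedding:

```python
import torch


def untie_then_quantize(model, quantize_embedding_):
    """Hypothetical helper: break weight tying, then quantize the input embedding."""
    if getattr(model.config, "tie_word_embeddings", False):
        # Give lm_head its own copy of the shared weight so that quantizing the
        # input embedding no longer affects the output projection.
        output_embeddings = model.get_output_embeddings()
        input_embeddings = model.get_input_embeddings()
        output_embeddings.weight = torch.nn.Parameter(
            input_embeddings.weight.detach().clone()
        )
        model.config.tie_word_embeddings = False

    # With the tie broken, quantize only the input embedding in place.
    quantize_embedding_(model.get_input_embeddings())
    return model
```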
         self.post_init()

     @staticmethod
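For reference, a minimal usage sketch of the new flag (the checkpoint name and quant type string are illustrative; assumes torchao is installed and a transformers build that includes this change):

```python
from transformers import AutoModelForCausalLM, TorchAoConfig

# Quantize the input embedding too, and break the input/output embedding tie
# first so that only the input embedding ends up quantized.
quantization_config = TorchAoConfig(
    "int8_weight_only",
    include_embedding=True,
    untie_embedding_weights=True,
)

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct",  # illustrative checkpoint with tied embeddings
    device_map="auto",
    quantization_config=quantization_config,
)
```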
Another solution that would be a tad cleaner would be to not quantize the embeddings here and do it in _process_model_after_weight_loading

yeah, we can see if we can explore this path. The main constraint here is that we need to save a valid quantization_config together with the quantized model, so people can load the model + quantization_config in other frameworks like lm-eval and vllm
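As a sketch of that constraint (continuing the hypothetical example above; `safe_serialization=False` follows the usual torchao serialization guidance, and the output path is made up):

```python
from transformers import AutoModelForCausalLM

# Save the quantized weights together with the quantization_config so that the
# checkpoint can be reloaded in transformers or consumed by other frameworks
# (e.g. lm-eval, vllm) that read the same config.
output_dir = "qwen2.5-0.5b-int8wo-untied-embedding"
model.save_pretrained(output_dir, safe_serialization=False)

# Reloading picks the saved quantization_config back up automatically.
reloaded = AutoModelForCausalLM.from_pretrained(output_dir, device_map="auto")
```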