10 changes: 10 additions & 0 deletions src/transformers/quantizers/quantizer_torchao.py
@@ -247,6 +247,16 @@ def create_quantized_param(
module._parameters[tensor_name] = torch.nn.Parameter(
    param_value, requires_grad=param_value.requires_grad
).to(device=target_device)
# If we are quantizing tied parameters, the correct order to avoid tying the quantized weights is:
# 1. load the weight into the model
# 2. run tie_weights to populate the tied weights
# 3. quantize
input_embed = model.get_input_embeddings()
Member

Another solution, which would be a tad cleaner, would be to not quantize the embeddings here and instead do it in _process_model_after_weight_loading.
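A rough sketch of that alternative (not the PR's implementation; the hook name comes from the existing HfQuantizer API, the body is illustrative and assumes torchao's quantize_ would be applied to the embedding at this stage):

def _process_model_after_weight_loading(self, model, **kwargs):
    # Embeddings were skipped during weight loading, so by now all weights
    # (including any weight tying) are settled and can be quantized safely.
    if self.quantization_config.include_embedding:
        input_embed = model.get_input_embeddings()
        if self.quantization_config.untie_embedding_weights:
            model.tie_weights()  # repopulate lm_head from the loaded embedding
            setattr(model.config.get_text_config(decoder=True), "tie_word_embeddings", False)
        # quantize the now-settled embedding here, e.g. torchao.quantization.quantize_(input_embed, config)
    return model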

Contributor Author

Yeah, we can explore this path. The main constraint here is that we need to save a valid quantization_config together with the quantized model, so people can load the model + quantization_config in other frameworks like lm-eval and vllm.
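For context, a minimal sketch of that round trip (model name and quant type are placeholders, not from this PR): the quantization_config written into config.json at save time is what lm-eval or vllm would read back.

from transformers import AutoModelForCausalLM, TorchAoConfig

quantization_config = TorchAoConfig("int4_weight_only", group_size=128)  # placeholder quant type
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",  # placeholder model
    torch_dtype="auto",
    quantization_config=quantization_config,
)
# torchao tensor subclasses are not supported by safetensors, hence safe_serialization=False
model.save_pretrained("llama-3.2-1b-int4wo", safe_serialization=False)
# downstream loaders (lm-eval, vllm) rely on the quantization_config saved in config.json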

if self.quantization_config.untie_embedding_weights and id(module) == id(input_embed):
    model.tie_weights()
    setattr(model.config.get_text_config(decoder=True), "tie_word_embeddings", False)

# handle AOPerModuleConfig, introduced in torchao 0.11.0+
if self.quantization_config._get_ao_version() > version.Version("0.10.0"):
    from torchao.quantization import AOPerModuleConfig
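For reference, a hedged sketch of what an AOPerModuleConfig quant_type looks like on the user side (the concrete configs are illustrative, and the "_default" fallback key follows the torchao 0.11 API as an assumption):

from torchao.quantization import AOPerModuleConfig, Int4WeightOnlyConfig, Int8WeightOnlyConfig

quant_type = AOPerModuleConfig({
    "_default": Int4WeightOnlyConfig(group_size=128),  # fallback config for all quantized modules
    "model.embed_tokens": Int8WeightOnlyConfig(),      # per-module override by fully qualified name
})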
6 changes: 6 additions & 0 deletions src/transformers/utils/quantization_config.py
@@ -1555,6 +1555,7 @@ class TorchAoConfig(QuantizationConfigMixin):
modules_to_not_convert: Optional[List]
quant_type_kwargs: Dict[str, Any]
include_embedding: bool
untie_embedding_weights: bool

"""This is a config class for torchao quantization/sparsity techniques.

Expand All @@ -1569,6 +1570,9 @@ class TorchAoConfig(QuantizationConfigMixin):
include_embedding (`bool`, defaults to `False`):
    Whether to include the embedding in quantization or not. If this flag is set, the input embedding
    will also be removed from the modules_to_not_convert list.
untie_embedding_weights (`bool`, defaults to `False`):
    Whether to untie the weights when quantizing input embedding weights that are tied to other weights.
kwargs (`Dict[str, Any]`, *optional*):
    The keyword arguments for the chosen type of quantization; for example, int4_weight_only quantization
    currently supports the two keyword arguments `group_size` and `inner_k_tiles`. More API examples and documentation of arguments can be found in
@@ -1614,13 +1618,15 @@ def __init__(
    quant_type: Union[str, "AOBaseConfig"],  # noqa: F821
    modules_to_not_convert: Optional[List] = None,
    include_embedding: bool = False,
    untie_embedding_weights: bool = False,
    **kwargs,
):
    self.quant_method = QuantizationMethod.TORCHAO
    self.quant_type = quant_type
    self.modules_to_not_convert = modules_to_not_convert
    self.quant_type_kwargs = kwargs.get("quant_type_kwargs", kwargs)
    self.include_embedding = include_embedding
    self.untie_embedding_weights = untie_embedding_weights
Member

Make sure not to save this key when serializing the model config, as it is redundant with tie_word_embeddings.

Contributor Author

It's a bit different, I think. Together with tie_word_embeddings == True, this flag was specifically trying to achieve the order:

  1. load unquantized weight
  2. tie weights
  3. quantize (this will untie the weights)

while tie_word_embeddings == True without the flag means:

  1. load unquantized weight
  2. quantize
  3. tie weights (tied quantized weights)

and tie_word_embeddings == False means:

  1. load unquantized weight
  2. quantize
  3. don't tie weights
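
A hedged usage sketch of that first ordering (model name and quant type are placeholders): with include_embedding=True and untie_embedding_weights=True, the embedding is loaded, tied into lm_head, then quantized, and the config is saved with tie_word_embeddings=False.

from transformers import AutoModelForCausalLM, TorchAoConfig

quantization_config = TorchAoConfig(
    "int8_weight_only",            # placeholder quant type
    include_embedding=True,        # quantize the input embedding as well
    untie_embedding_weights=True,  # load -> tie_weights -> quantize, then mark tie_word_embeddings=False
)
model = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceTB/SmolLM2-135M",  # placeholder: a model with tied embeddings
    torch_dtype="auto",
    quantization_config=quantization_config,
)
# after quantization, input and output embeddings should no longer share a parameter
assert model.get_input_embeddings().weight is not model.get_output_embeddings().weight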

    self.post_init()

@staticmethod
Expand Down