
Conversation

@drisspg (Contributor) commented Mar 4, 2025

What does this PR do?

This PR updates the TorchAO integration. In torchao release 0.9, quantization moved to being specified via config objects; this PR supports those configs and allows them to be serialized to config.json when saving the model.

Main Changes

  • Enhanced Configuration Support:

    • Extended TorchAoConfig to accept two types of configurations (see the sketch after this list):
      • String-based configurations (the original approach, kept for backward-compatibility concerns)
      • New AOBaseConfig object instances for more advanced configuration (the new recommended path)
  • Serialization & Deserialization:

    • Added functionality to properly serialize and deserialize AOBaseConfig objects
    • Allows configs to be saved to disk, shared between applications, and versioned for compatibility
    • Implemented through new to_dict() and from_dict() methods
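
For illustration, a minimal sketch of the two configuration styles (assuming torchao >= 0.9, where Int4WeightOnlyConfig is one of the new config classes; the checkpoint name is just a placeholder):

from torchao.quantization import Int4WeightOnlyConfig
from transformers import AutoModelForCausalLM, TorchAoConfig

# Original string-based path (kept for backward compatibility):
string_config = TorchAoConfig("int4_weight_only", group_size=128)

# New AOBaseConfig path (the recommended approach going forward):
object_config = TorchAoConfig(quant_type=Int4WeightOnlyConfig(group_size=128))

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m", device_map="auto", quantization_config=object_config
)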

The serialization format uses a structured dictionary with:

{
    "_type": "ConfigClassName",  # Class name, not full module path
    "_version": 1,               # Version from the class's VERSION attribute
    "_data": {                   # Actual configuration parameters
        "param1": value1,
        "param2": value2,
        # Nested objects also get serialized with their types
    }
}
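
A hypothetical round trip through this format (to_dict() and from_dict() are the methods this PR adds; the exact dictionary contents shown in the comment are illustrative assumptions):

config = TorchAoConfig(quant_type=Int4WeightOnlyConfig(group_size=128))

serialized = config.to_dict()
# serialized["quant_type"] would look roughly like:
# {"_type": "Int4WeightOnlyConfig", "_version": 1, "_data": {"group_size": 128, ...}}

restored = TorchAoConfig.from_dict(serialized)  # rebuilds the AOBaseConfig instance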

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@drisspg drisspg mentioned this pull request Mar 5, 2025
@drisspg drisspg mentioned this pull request Mar 6, 2025
@drisspg drisspg force-pushed the ao-base-configs branch 2 times, most recently from f527770 to c39df74 Compare March 12, 2025 22:54
if version.parse(importlib.metadata.version("accelerate")) > version.parse("0.19.0"):
    from accelerate.utils import CustomDtype

return None
@drisspg (Contributor, Author) commented Mar 12, 2025:
Not all configs map cleanly to CustomDtype; could someone tell me what should be used here?

Member:
You can use the closest dtype that the quantized model will have on average! This is used to approximate the size of the model so that we can dispatch the weights correctly across the GPUs.

@drisspg (Contributor, Author):
Makes sense. Hmm, this is a little hard to do in general without requiring extra info. I can do a fuzzy match based on the class name, since we have a pretty standard naming scheme there. Is it okay if we default to torch.int8 when it doesn't match, or do things blow up? All our techniques are int8 or less.
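
For context, a minimal sketch of what such a class-name fuzzy match could look like (this body is an illustrative assumption; the actual fuzzy_match_size helper in the PR may differ):

import re

def fuzzy_match_size(config_name: str):
    # Pull the bit-width digit out of names like "Int4WeightOnlyConfig" -> "4"
    # or "Int8WeightOnlyConfig" -> "8"; return None when no digit is present.
    match = re.search(r"(\d+)", config_name)
    return match.group(1) if match else None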

@drisspg drisspg force-pushed the ao-base-configs branch 3 times, most recently from 79c7b13 to 2d02d5d Compare March 13, 2025 00:08
@drisspg drisspg marked this pull request as ready for review March 13, 2025 00:08
@SunMarc (Member) left a comment:
Thanks for the update! Left a couple of comments. If it makes sense, in the future we can also force the user to use torchao >= 0.10.0 after a deprecation cycle, so that we don't have to maintain the old behavior.


@MekkCyber (Contributor) commented:

Thanks for the update!

@drisspg drisspg force-pushed the ao-base-configs branch 12 times, most recently from e5e7c32 to fa21f9a Compare March 17, 2025 21:08
@drisspg drisspg requested a review from SunMarc March 18, 2025 03:33
@drisspg drisspg requested a review from MekkCyber March 18, 2025 03:33
@SunMarc (Member) left a comment:
Thanks for fixing everything. One last thing would be to make sure that autoquant still works with the latest torchao.


> **⚠️ DEPRECATION WARNING**
>
> Starting with version 0.10.0, the string-based API for quantization configuration (e.g., `TorchAoConfig("int4_weight_only", group_size=128)`) is **deprecated** and will be removed in a future release.
Contributor:
I remember it's fine for transformers to always depend on the most recent torchao versions

Member:
If possible, we would like to support older versions of torchao too, but I feel like for now it's fine for the user to download the most recent version of torchao.

if isinstance(quant_type, AOBaseConfig):
    # Extract size digit using fuzzy match on the class name
    config_name = quant_type.__class__.__name__
    size_digit = fuzzy_match_size(config_name)
Contributor:
This seems a bit fragile? E.g., what would it look like for mx, fp4, etc.?

Comment on lines +156 to +160
if size_digit == "4":
    return CustomDtype.INT4
else:
    # Default to int8
    return torch.int8
@jerryzh168 (Contributor) commented Mar 19, 2025:
I'm wondering if these are really needed. cc @SunMarc, when are these used?

Member:
They are used when calculating the appropriate device_map (e.g., to know how to dispatch the layers across the different GPUs). This is needed in the torchao case because the model architecture is not changed prior to calculating the device_map.
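
To make that concrete: the estimate is roughly parameter count times bytes per element, so the dtype reported here directly drives how layers get packed onto devices (a simplified sketch; accelerate's real logic walks modules and buffers):

def approx_quantized_size_bytes(num_params: int, bits_per_weight: int) -> float:
    # e.g. INT4 -> 0.5 bytes per weight, int8 -> 1 byte per weight
    return num_params * bits_per_weight / 8

# A 7B-parameter model quantized to 4 bits is estimated at ~3.5 GB of weights:
approx_quantized_size_bytes(7_000_000_000, 4)  # 3.5e9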

@SunMarc (Member) left a comment:
Thanks for iterating!

@SunMarc SunMarc merged commit e8d9603 into huggingface:main Mar 19, 2025
21 checks passed
@BenjaminBossan (Member) commented:

Hey @drisspg @SunMarc, after this PR I get an error with this code:

from transformers import AutoModelForCausalLM, TorchAoConfig

quantization_config = TorchAoConfig(quant_type="int8_weight_only")
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m", device_map=0, quantization_config=quantization_config
)

It raises:

    model = AutoModelForCausalLM.from_pretrained(
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/name/work/forks/transformers/src/transformers/models/auto/auto_factory.py", line 573, in from_pretrained
    return model_class.from_pretrained(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/name/work/forks/transformers/src/transformers/modeling_utils.py", line 272, in _wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/name/work/forks/transformers/src/transformers/modeling_utils.py", line 4442, in from_pretrained
    ) = cls._load_pretrained_model(
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/name/work/forks/transformers/src/transformers/modeling_utils.py", line 4871, in _load_pretrained_model
    disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model(
                                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/name/anaconda3/envs/peft/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/name/work/forks/transformers/src/transformers/modeling_utils.py", line 853, in _load_state_dict_into_meta_model
    hf_quantizer.create_quantized_param(
  File "/home/name/work/forks/transformers/src/transformers/quantizers/quantizer_torchao.py", line 239, in create_quantized_param
    quantize_(module, self.quantization_config.get_apply_tensor_subclass(), set_inductor_config=False)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'TorchAoConfig' object has no attribute 'get_apply_tensor_subclass'

This is because get_apply_tensor_subclass was renamed to get_quantize_config but the old name is still being used:

https://github.com/huggingface/transformers/blob/1ddb64937cd31bd25df3213b1dc275396ef695cd/src/transformers/quantizers/quantizer_torchao.py#L239C56-L239C81

If there is no important reason to use the new name, I would suggest reverting to the old name, because other libraries may also rely on it (for instance, PEFT would otherwise need to be updated).
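
One hypothetical stopgap, if the rename were kept, would be to alias the old name (an illustrative sketch, not the fix that was merged):

from transformers import TorchAoConfig

# Hypothetical shim: expose the old name as an alias of the renamed method
# so downstream callers such as PEFT keep working during a transition.
if not hasattr(TorchAoConfig, "get_apply_tensor_subclass"):
    TorchAoConfig.get_apply_tensor_subclass = TorchAoConfig.get_quantize_config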

@SunMarc (Member) commented Mar 20, 2025:

I will do a quick PR to revert to the old name!

@SunMarc (Member) commented Mar 20, 2025:

Fixed it here #36849

jerryzh168 added a commit to jerryzh168/transformers that referenced this pull request Mar 25, 2025
Summary:
We add the new torchao API support in HF transformers (huggingface#36526). One thing that's missing is that it does not account for the int4 weight-only quant config only working on CUDA; this PR adds back the workaround.

Also updated the version requirement to > 0.9 temporarily so that we can use the torchao nightly before 0.10 is released; we should change this back before landing.

Test Plan:
local test: https://gist.github.com/jerryzh168/0e749d0dab40e2a62a7f2e48639f77b5
(we can set up a deserialization test later, once we can quantize a small model and host it in a stable place like TinyLlama/TinyLlama-1.1B-Chat-v1.0)

jerryzh168 added a commit to jerryzh168/transformers that referenced this pull request Mar 25, 2025
assert (
    len(quant_type) == 1 and "default" in quant_type
), "Expected only one key 'default' in quant_type dictionary"
quant_type = quant_type["default"]
A reviewer commented:
This one breaks

def check_serialization_expected_output(self, device, expected_output):

zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request May 14, 2025