Add option for ao base configs #36526
Conversation
```python
if version.parse(importlib.metadata.version("accelerate")) > version.parse("0.19.0"):
    from accelerate.utils import CustomDtype
    ...
return None
```
Not all configs map cleanly to CustomDtype; could someone tell me what should be used here?
You can use the closest dtype that the quantized model will have on average! This is used to approximate the size of the model so that we can dispatch the weights correctly across the GPUs.
Makes sense. Hmm, this is a little hard to do in general without enforcing extra info. I can do a fuzzy match based on the class `__name__`, since we have a pretty standard naming scheme there. Is it okay if we default to torch.int8 when it doesn't match, or do things blow up? All our techniques are int8 or less.
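The fuzzy match discussed here can be sketched as a small helper that pulls the first bit-width digit out of a config class name. This is a hypothetical illustration of the idea, not the helper that was actually merged:

```python
import re


def fuzzy_match_size(config_name):
    """Extract a bit-width digit from a quantization config class name,
    e.g. 'Int4WeightOnlyConfig' -> '4'. Returns None when no digit is
    found, in which case the caller can fall back to torch.int8.
    Hypothetical sketch; the merged helper may differ."""
    match = re.search(r"\d+", config_name)
    if match:
        return match.group(0)
    return None
```

Names like `Int4WeightOnlyConfig` follow the standard scheme mentioned above, so the first digit run is the bit width; a name with no digit (an autoquant-style config, say) falls through to the int8 default.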
SunMarc left a comment:
Thanks for the update! Left a couple of comments. If it makes sense in the future, we can also force the user to use torchao >= 0.10.0 after a deprecation cycle, so that we don't have to maintain the old behavior.
SunMarc left a comment:
Thanks for fixing everything. One last thing would be to make sure that autoquant still works with the latest torchao.
> **⚠️ DEPRECATION WARNING**
>
> Starting with version 0.10.0, the string-based API for quantization configuration (e.g., `TorchAoConfig("int4_weight_only", group_size=128)`) is **deprecated** and will be removed in a future release.
I remember it's fine for transformers to always depend on the most recent torchao versions
If possible, we would like to support older versions of torchao too, but I feel like for now it's fine for the user to download the most recent version of torchao.
```python
if isinstance(quant_type, AOBaseConfig):
    # Extract size digit using fuzzy match on the class name
    config_name = quant_type.__class__.__name__
    size_digit = fuzzy_match_size(config_name)
```
This seems a bit fragile? e.g., what would it look like for mx, fp4, etc.?
```python
if size_digit == "4":
    return CustomDtype.INT4
else:
    # Default to int8
    return torch.int8
```
I'm wondering if these are really needed, cc @SunMarc when are these used?
They are used when calculating the appropriate device_map (e.g. to know how to dispatch the layers across the different GPUs). This is needed in the torchao case because the model architecture is not changed prior to calculating the device_map.
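The role of the reported dtype can be sketched numerically: the quantizer's target dtype feeds a rough bytes-per-parameter count, which is what lets the dispatcher split layers across GPUs before the model is actually quantized. The table below is illustrative only, not transformers' actual accounting:

```python
# Rough per-parameter storage cost by target dtype; illustrative
# numbers for reasoning about device_map planning, not the real
# accelerate/transformers implementation.
BYTES_PER_PARAM = {
    "float16": 2.0,
    "int8": 1.0,
    "int4": 0.5,  # packed: two 4-bit weights per byte
}


def estimated_model_bytes(num_params, dtype_name):
    """Approximate weight-storage size used when planning a device_map."""
    return int(num_params * BYTES_PER_PARAM[dtype_name])


# A 125M-parameter model quantized to int4 weights needs roughly
# half the bytes of its int8 counterpart.
int4_size = estimated_model_bytes(125_000_000, "int4")
```

This is why defaulting to int8 on an unmatched config name is safe-ish: the estimate errs on the large side, so dispatch still fits, just less tightly packed.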
SunMarc left a comment:
Thanks for iterating!
Hey @drisspg @SunMarc, after this PR I get an error with this code:

```python
from transformers import AutoModelForCausalLM, TorchAoConfig

quantization_config = TorchAoConfig(quant_type="int8_weight_only")
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m", device_map=0, quantization_config=quantization_config
)
```

It raises an error. If there is no important reason to use the new name, I would suggest reverting to the old name, because other libraries may also rely on that name (for instance, PEFT would otherwise need to be updated).
I will do a quick PR to revert to the old name!
Fixed it here: #36849
Summary: We add the new torchao API support in hf transformers: huggingface#36526. One thing that's missing is that it does not account for the int4 weight-only quant config only working on CUDA; this PR adds back the workaround. Also updated the version requirement to > 0.9 temporarily so that we can use the torchao nightly before 0.10 is released; we should change this back before landing.

Test Plan: local test: https://gist.github.com/jerryzh168/0e749d0dab40e2a62a7f2e48639f77b5 (we can set up a deserialization test later when we can quantize a small model and host it in a stable place like TinyLlama/TinyLlama-1.1B-Chat-v1.0)
```python
assert (
    len(quant_type) == 1 and "default" in quant_type
), "Expected only one key 'default' in quant_type dictionary"
quant_type = quant_type["default"]
```
This one breaks
`def check_serialization_expected_output(self, device, expected_output):`
What does this PR do?
This PR updates the TorchAO integration. In release 0.9, torchao moved to specifying quantization via config objects, which allows these configs to be serialized into config.json when saving the model.
Main Changes
Enhanced Configuration Support:
Updated `TorchAoConfig` to accept two types of configurations: plain strings (the existing path), and `AOBaseConfig` object instances for more advanced configuration (the new blessed path).

Serialization & Deserialization:
Added support for serializing and deserializing `AOBaseConfig` objects via `to_dict()` and `from_dict()` methods.

The serialization format uses a structured dictionary:

```python
{
    "_type": "ConfigClassName",  # Class name, not full module path
    "_version": 1,               # Version from the class's VERSION attribute
    "_data": {                   # Actual configuration parameters
        "param1": value1,
        "param2": value2,
        # Nested objects also get serialized with their types
    },
}
```
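The round trip can be illustrated with a minimal stdlib-only sketch. The dataclass, registry, and function names here are hypothetical stand-ins for whatever the PR actually registers; only the `_type`/`_version`/`_data` shape mirrors the format above:

```python
import dataclasses


@dataclasses.dataclass
class Int4WeightOnlyConfig:
    """Toy stand-in for an AOBaseConfig subclass."""

    VERSION = 1  # class attribute, not a serialized field
    group_size: int = 128


# Hypothetical registry mapping "_type" names back to classes;
# the real code may resolve classes differently.
CONFIG_REGISTRY = {"Int4WeightOnlyConfig": Int4WeightOnlyConfig}


def config_to_dict(cfg):
    return {
        "_type": type(cfg).__name__,       # class name, not full module path
        "_version": cfg.VERSION,           # from the class's VERSION attribute
        "_data": dataclasses.asdict(cfg),  # actual configuration parameters
    }


def config_from_dict(d):
    cls = CONFIG_REGISTRY[d["_type"]]
    return cls(**d["_data"])
```

Storing only the class name (plus a version) rather than a full module path keeps the config.json stable across internal module reorganizations, at the cost of needing a name-to-class lookup on load.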
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.