Skip to content

Conversation

@andrewor14
Copy link
Contributor

Summary: TorchAoConfig optionally contains a
torchao.dtypes.Layout object which is a dataclass and not JSON serializable, and so the following fails:

import json
from torchao.dtypes import TensorCoreTiledLayout
from transformers import TorchAoConfig

config = TorchAoConfig("int4_weight_only", layout=TensorCoreTiledLayout())

config.to_json_string()

json.dumps(config.to_dict())

This also causes quantized_model.save_pretrained(...) to fail because the first step of this call is to JSON serialize the config. Fixes pytorch/ao#1704.

Test Plan:
python tests/quantization/torchao_integration/test_torchao.py -k test_json_serializable

**Summary:** TorchAoConfig optionally contains a
`torchao.dtypes.Layout` object which is a dataclass and not
JSON serializable, and so the following fails:

```
import json
from torchao.dtypes import TensorCoreTiledLayout
from transformers import TorchAoConfig

config = TorchAoConfig("int4_weight_only", layout=TensorCoreTiledLayout())

config.to_json_string()

json.dumps(config.to_dict())
```

This also causes `quantized_model.save_pretrained(...)` to
fail because the first step of this call is to JSON serialize
the config. Fixes pytorch/ao#1704.

**Test Plan:**
python tests/quantization/torchao_integration/test_torchao.py -k test_json_serializable
@andrewor14
Copy link
Contributor Author

@SunMarc @ArthurZucker @jerryzh168 Please take a look, thanks!

@jiqing-feng
Copy link
Contributor

Yes, it solves my issue. Thanks!

@MekkCyber
Copy link
Contributor

LGTM @andrewor14 thanks for the fix !

Copy link
Member

@SunMarc SunMarc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice !

@SunMarc SunMarc merged commit fdcfdbf into huggingface:main Feb 18, 2025
25 checks passed
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request Feb 21, 2025
**Summary:** TorchAoConfig optionally contains a
`torchao.dtypes.Layout` object which is a dataclass and not
JSON serializable, and so the following fails:

```
import json
from torchao.dtypes import TensorCoreTiledLayout
from transformers import TorchAoConfig

config = TorchAoConfig("int4_weight_only", layout=TensorCoreTiledLayout())

config.to_json_string()

json.dumps(config.to_dict())
```

This also causes `quantized_model.save_pretrained(...)` to
fail because the first step of this call is to JSON serialize
the config. Fixes pytorch/ao#1704.

**Test Plan:**
python tests/quantization/torchao_integration/test_torchao.py -k test_json_serializable

Co-authored-by: Mohamed Mekkouri <[email protected]>
Co-authored-by: Marc Sun <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

int4_weight_only api got error when saving transformers models

5 participants