Skip to content

Commit 024259e

Browse files
andrewor14, MekkCyber, and SunMarc
authored and committed
Fix TorchAoConfig not JSON serializable (huggingface#36206)
**Summary:** TorchAoConfig optionally contains a `torchao.dtypes.Layout` object, which is a dataclass and is not JSON serializable, so the following fails:

```
import json
from torchao.dtypes import TensorCoreTiledLayout
from transformers import TorchAoConfig

config = TorchAoConfig("int4_weight_only", layout=TensorCoreTiledLayout())

config.to_json_string()
json.dumps(config.to_dict())
```

This also causes `quantized_model.save_pretrained(...)` to fail, because the first step of that call is to JSON-serialize the config.

Fixes pytorch/ao#1704.

**Test Plan:**
python tests/quantization/torchao_integration/test_torchao.py -k test_json_serializable

Co-authored-by: Mohamed Mekkouri <[email protected]>
Co-authored-by: Marc Sun <[email protected]>
1 parent dd8c733 commit 024259e

File tree

2 files changed

+31
-3
lines changed

2 files changed

+31
-3
lines changed

src/transformers/utils/quantization_config.py

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515
# See the License for the specific language governing permissions and
1616
# limitations under the License.
1717
import copy
18+
import dataclasses
1819
import importlib.metadata
1920
import json
2021
import os
@@ -1539,6 +1540,21 @@ def __repr__(self):
15391540
config_dict = self.to_dict()
15401541
return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"
15411542

1543+
def to_dict(self) -> Dict[str, Any]:
1544+
"""
1545+
Serializes this instance to a Python dictionary, converting any `torchao.dtypes.Layout`
1546+
dataclasses to simple dicts.
1547+
1548+
Returns:
1549+
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
1550+
"""
1551+
d = super().to_dict()
1552+
if "quant_type_kwargs" in d and "layout" in d["quant_type_kwargs"]:
1553+
layout = d["quant_type_kwargs"]["layout"]
1554+
layout = dataclasses.asdict(layout)
1555+
d["quant_type_kwargs"]["layout"] = layout
1556+
return d
1557+
15421558

15431559
@dataclass
15441560
class BitNetConfig(QuantizationConfigMixin):

tests/quantization/torchao_integration/test_torchao.py

Lines changed: 15 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -31,16 +31,18 @@
3131
import torch
3232

3333
if is_torchao_available():
34-
from torchao.dtypes import AffineQuantizedTensor
35-
from torchao.dtypes.affine_quantized_tensor import TensorCoreTiledLayoutType
34+
from torchao.dtypes import (
35+
AffineQuantizedTensor,
36+
TensorCoreTiledLayout,
37+
)
3638

3739

3840
def check_torchao_quantized(test_module, qlayer, batch_size=1, context_size=1024):
3941
weight = qlayer.weight
4042
test_module.assertTrue(isinstance(weight, AffineQuantizedTensor))
4143
test_module.assertEqual(weight.quant_min, 0)
4244
test_module.assertEqual(weight.quant_max, 15)
43-
test_module.assertTrue(isinstance(weight.layout_type, TensorCoreTiledLayoutType))
45+
test_module.assertTrue(isinstance(weight.layout, TensorCoreTiledLayout))
4446

4547

4648
def check_forward(test_module, model, batch_size=1, context_size=1024):
@@ -82,6 +84,16 @@ def test_repr(self):
8284
quantization_config = TorchAoConfig("int4_weight_only", modules_to_not_convert=["conv"], group_size=8)
8385
repr(quantization_config)
8486

87+
def test_json_serializable(self):
88+
"""
89+
Check that the config dict can be JSON serialized.
90+
"""
91+
quantization_config = TorchAoConfig("int4_weight_only", group_size=32, layout=TensorCoreTiledLayout())
92+
d = quantization_config.to_dict()
93+
self.assertIsInstance(d["quant_type_kwargs"]["layout"], dict)
94+
self.assertTrue("inner_k_tiles" in d["quant_type_kwargs"]["layout"])
95+
quantization_config.to_json_string(use_diff=False)
96+
8597

8698
@require_torch_gpu
8799
@require_torchao

0 commit comments

Comments
 (0)