
Commit d258a11

Config Serialization No Deps (#1875)
Hand-rolled config serialization, no external dependencies.
1 parent 8c8388d commit d258a11
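
The API exercised by the new test is config_to_dict / config_from_dict from torchao.core.config: they turn an AOBaseConfig subclass into a plain, JSON-serializable dict and rebuild it later, so configs can be persisted with nothing beyond the standard library. A minimal usage sketch of the same round trip the test below performs (the int4_config.json file name is illustrative, not part of this PR):

import json

from torchao.core.config import config_from_dict, config_to_dict
from torchao.quantization.quant_api import Int4WeightOnlyConfig

config = Int4WeightOnlyConfig(group_size=32)

# Serialize the config to a plain dict and persist it as JSON.
with open("int4_config.json", "w") as f:
    json.dump(config_to_dict(config), f)

# Later (e.g. in another process): load the JSON and rebuild the config object.
with open("int4_config.json", "r") as f:
    restored = config_from_dict(json.load(f))

assert isinstance(restored, Int4WeightOnlyConfig)
assert restored.group_size == 32

This mirrors test_reconstructable_dict_file_round_trip below, which runs the round trip over every config class in the list and compares attributes afterwards.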

File tree

3 files changed: +409 −1 lines
Lines changed: 172 additions & 0 deletions
@@ -0,0 +1,172 @@
import json
import os
import tempfile
from dataclasses import dataclass
from unittest import mock

import pytest
import torch

from torchao.core.config import (
    AOBaseConfig,
    VersionMismatchError,
    config_from_dict,
    config_to_dict,
)
from torchao.quantization.quant_api import (
    Float8DynamicActivationFloat8WeightConfig,
    Float8WeightOnlyConfig,
    FPXWeightOnlyConfig,
    GemliteUIntXWeightOnlyConfig,
    Int4DynamicActivationInt4WeightConfig,
    Int4WeightOnlyConfig,
    Int8DynamicActivationInt4WeightConfig,
    Int8DynamicActivationInt8WeightConfig,
    Int8WeightOnlyConfig,
    PerRow,
    UIntXWeightOnlyConfig,
)
from torchao.sparsity.sparse_api import BlockSparseWeightConfig, SemiSparseWeightConfig

# Define test configurations as fixtures
configs = [
    Float8DynamicActivationFloat8WeightConfig(),
    Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
    Float8WeightOnlyConfig(
        weight_dtype=torch.float8_e4m3fn,
    ),
    UIntXWeightOnlyConfig(dtype=torch.uint1),
    Int4DynamicActivationInt4WeightConfig(),
    Int4WeightOnlyConfig(
        group_size=32,
    ),
    Int8DynamicActivationInt4WeightConfig(
        group_size=64,
    ),
    Int8DynamicActivationInt8WeightConfig(),
    # Int8DynamicActivationInt8WeightConfig(layout=SemiSparseLayout()),
    Int8WeightOnlyConfig(
        group_size=128,
    ),
    UIntXWeightOnlyConfig(
        dtype=torch.uint3,
        group_size=32,
        use_hqq=True,
    ),
    GemliteUIntXWeightOnlyConfig(
        group_size=128,  # Optional, has default of 64
        bit_width=8,  # Optional, has default of 4
        packing_bitwidth=8,  # Optional, has default of 32
        contiguous=True,  # Optional, has default of None
    ),
    FPXWeightOnlyConfig(ebits=4, mbits=8),
    # Sparsity configs
    SemiSparseWeightConfig(),
    BlockSparseWeightConfig(blocksize=128),
]


# Create ids for better test naming
def get_config_ids(configs):
    if not isinstance(configs, list):
        configs = [configs]
    return [config.__class__.__name__ for config in configs]


@pytest.mark.parametrize("config", configs, ids=get_config_ids)
def test_reconstructable_dict_file_round_trip(config):
    """Test saving and loading reconstructable dicts to/from JSON files."""
    # Get a reconstructable dict
    reconstructable = config_to_dict(config)

    # Create a temporary file to save the JSON
    with tempfile.NamedTemporaryFile(
        mode="w+", suffix=".json", delete=False
    ) as temp_file:
        # Write the reconstructable dict as JSON
        json.dump(reconstructable, temp_file)
        temp_file_path = temp_file.name

    try:
        # Read back the JSON file
        with open(temp_file_path, "r") as file:
            loaded_dict = json.load(file)

        # Reconstruct from the loaded dict
        reconstructed = config_from_dict(loaded_dict)

        # Check it's the right class
        assert isinstance(reconstructed, config.__class__)

        # Verify attributes match
        for attr_name in config.__dict__:
            if not attr_name.startswith("_"):  # Skip private attributes
                original_value = getattr(config, attr_name)
                reconstructed_value = getattr(reconstructed, attr_name)

                # Special handling for torch dtypes
                if (
                    hasattr(original_value, "__module__")
                    and original_value.__module__ == "torch"
                ):
                    assert (
                        str(original_value) == str(reconstructed_value)
                    ), f"Attribute {attr_name} mismatch after file round trip for {config.__class__.__name__}"
                else:
                    assert (
                        original_value == reconstructed_value
                    ), f"Attribute {attr_name} mismatch after file round trip for {config.__class__.__name__}"

    finally:
        # Clean up the temporary file
        if os.path.exists(temp_file_path):
            os.unlink(temp_file_path)


# Define a dummy config in a non-allowed module
@dataclass
class DummyNonAllowedConfig(AOBaseConfig):
    VERSION = 2
    value: int = 42


def test_disallowed_modules():
    """Test that configs from non-allowed modules are rejected during reconstruction."""
    # Create a config from a non-allowed module
    dummy_config = DummyNonAllowedConfig()
    reconstructable = config_to_dict(dummy_config)

    with pytest.raises(
        ValueError,
        match="Failed to find class DummyNonAllowedConfig in any of the allowed modules",
    ):
        config_from_dict(reconstructable)

    # Use mock.patch as a context manager
    with mock.patch("torchao.core.config.ALLOWED_AO_MODULES", {__name__}):
        reconstructed = config_from_dict(reconstructable)
        assert isinstance(reconstructed, DummyNonAllowedConfig)
        assert reconstructed.value == 42
        assert reconstructed.VERSION == 2


def test_version_mismatch():
    """Test that version mismatch raises an error during reconstruction."""
    # Create a config
    dummy_config = DummyNonAllowedConfig()
    reconstructable = config_to_dict(dummy_config)

    # Modify the version in the dict to create a mismatch
    reconstructable["_version"] = 1

    # Patch to allow the module but should still fail due to version mismatch
    with mock.patch("torchao.core.config.ALLOWED_AO_MODULES", {__name__}):
        with pytest.raises(
            VersionMismatchError,
            match="Version mismatch for DummyNonAllowedConfig: stored version 1 != current version 2",
        ):
            config_from_dict(reconstructable)


if __name__ == "__main__":
    pytest.main([__file__])
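
Beyond the plain round trip, the last two tests pin down the format's guardrails: the serialized dict carries a top-level "_version" key that must match the class's VERSION at reconstruction time, and config_from_dict only resolves class names from torchao's allow-listed modules (ALLOWED_AO_MODULES). A short sketch of how a project-local config meets both checks, mirroring DummyNonAllowedConfig above (the class and field names here are illustrative):

from dataclasses import dataclass

from torchao.core.config import AOBaseConfig, config_to_dict

# Illustrative user-defined config; lives outside torchao's allow-listed modules.
@dataclass
class MyLocalConfig(AOBaseConfig):
    VERSION = 2
    scale: float = 1.0

d = config_to_dict(MyLocalConfig(scale=0.5))
print(d["_version"])  # stored version, compared against MyLocalConfig.VERSION on load

# config_from_dict(d) would raise ValueError here, because this module is not in
# torchao.core.config.ALLOWED_AO_MODULES -- exactly what test_disallowed_modules checks.
# Tampering with d["_version"] instead triggers VersionMismatchError, as in
# test_version_mismatch.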

0 commit comments
