-
Notifications
You must be signed in to change notification settings - Fork 528
/
Copy pathschema_check.py
291 lines (247 loc) · 10.2 KB
/
schema_check.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import dataclasses
import hashlib
import re
import typing
from enum import IntEnum
from typing import Any, Dict, Optional, Union
from torch._export.serde import schema
from torch._export.serde.union import _Union
class SchemaUpdateError(Exception):
    """Error raised by ``_check`` when a schema change is rejected."""
def _check(x, msg):
    """Raise ``SchemaUpdateError(msg)`` unless ``x`` is truthy.

    Unlike a bare ``assert``, this check is not stripped under ``python -O``.
    """
    if x:
        return
    raise SchemaUpdateError(msg)
def _staged_schema():
    """Describe the public definitions of ``schema`` as plain data.

    Public names whose object reports a ``__module__`` other than
    ``schema.__name__`` (i.e. re-exported imports) are skipped. Every
    remaining definition becomes one entry:

    * ``IntEnum`` subclass -> ``{"kind": "enum", "fields": {member: value}}``
    * ``_Union`` dataclass -> ``{"kind": "union", "fields": {...}}``
    * other dataclass      -> ``{"kind": "struct", "fields": {...}}``

    Each dataclass field maps to ``{"type": <rendered type>}``, plus
    ``"default": str(<default>)`` when the field has a default or a default
    factory. ``SCHEMA_VERSION`` (converted to a list) and ``TREESPEC_VERSION``
    are appended after all type entries.

    Returns:
        The staged schema dict. Its insertion order is significant, because
        callers hash its ``repr``.

    Raises:
        AssertionError: for an unsupported annotation, an ``Optional`` field
            whose default is not ``None``, an unrecognized definition, or a
            non-positive version number.
    """
    staged: Dict[str, Any] = {}
    definitions = {}

    # Generic container origins rendered as "<Head>[arg, ...]".
    container_heads = ((list, "List"), (dict, "Dict"), (tuple, "Tuple"))

    def render_type(annotation):
        # Concrete classes render as their bare class name.
        if isinstance(annotation, type):
            return annotation.__name__
        # String (forward) references must name another schema definition.
        if isinstance(annotation, str):
            assert annotation in definitions
            return annotation
        origin = typing.get_origin(annotation)
        if origin:
            if origin == Union:
                # Only the two-member form Union[X, None] is accepted.
                members = typing.get_args(annotation)
                assert len(members) == 2 and members[1] == type(None)
                return f"Optional[{render_type(members[0])}]"
            head = next(
                (label for container, label in container_heads if origin == container),
                None,
            )
            if head is None:
                raise AssertionError(f"Type {annotation} is not supported in export schema.")
            if origin == tuple and typing.get_args(annotation) == ():
                return "Tuple[()]"
            rendered_args = ", ".join(
                render_type(arg) for arg in typing.get_args(annotation)
            )
            return f"{head}[{rendered_args}]"
        # An empty-tuple literal renders as "()".
        if annotation == ():
            return "()"
        raise AssertionError(f"Type {annotation} is not supported in export schema.")

    def describe_field(owner, fld):
        rendered = render_type(fld.type)
        entry = {"type": rendered}
        if fld.default is not dataclasses.MISSING:
            default = fld.default
        elif fld.default_factory is not dataclasses.MISSING:
            default = fld.default_factory()
        else:
            default = dataclasses.MISSING
        # Optional fields must default to None; having no default also fails.
        if rendered.startswith("Optional[") and default is not None:
            raise AssertionError(
                f"Optional field {owner.__name__}.{fld.name} must have default value to be None."
            )
        if default is not dataclasses.MISSING:
            entry["default"] = str(default)
        return entry

    def describe_fields(owner):
        return {fld.name: describe_field(owner, fld) for fld in dataclasses.fields(owner)}

    # Collect the public names that belong to the schema module itself.
    for attr in dir(schema):
        if attr.startswith("_"):
            continue
        obj = getattr(schema, attr)
        if hasattr(obj, "__module__") and obj.__module__ != schema.__name__:
            continue
        definitions[attr] = obj

    for attr, obj in definitions.items():
        if not isinstance(obj, type):
            # Only the two version constants may be plain values.
            if isinstance(obj, (int, tuple)):
                assert attr in ("SCHEMA_VERSION", "TREESPEC_VERSION")
                continue
            raise AssertionError(f"Unknown variable {attr}: {obj}")
        if issubclass(obj, IntEnum):
            staged[attr] = {"kind": "enum", "fields": {member.name: member.value for member in obj}}
        elif not dataclasses.is_dataclass(obj):
            raise AssertionError(f"Unknown schema type {attr}: {obj}")
        elif issubclass(obj, _Union):
            staged[attr] = {"kind": "union", "fields": describe_fields(obj)}
        else:
            staged[attr] = {"kind": "struct", "fields": describe_fields(obj)}

    schema_version = list(definitions["SCHEMA_VERSION"])
    staged["SCHEMA_VERSION"] = schema_version
    assert all(component > 0 for component in schema_version)
    staged["TREESPEC_VERSION"] = definitions["TREESPEC_VERSION"]
    assert staged["TREESPEC_VERSION"] > 0
    return staged
def _diff_schema(dst, src):
    """Diff two staged schemas (dicts shaped like ``_staged_schema`` output).

    Args:
        dst: the baseline (previously committed) schema.
        src: the newly staged schema.

    Returns:
        ``(additions, subtractions)``.

        * Types present only in ``src`` go into ``additions`` with their
          full entry.
        * Types present only in ``dst`` go into ``subtractions`` with their
          full entry.
        * For a type present in both, ``{"fields": {...}}`` records the
          fields that were added or removed.
        * A struct field that gained a default is reported as an addition
          ``{"default": ...}``; one that lost its default is reported as a
          subtraction.
        * ``SCHEMA_VERSION`` and ``TREESPEC_VERSION`` are not diffed.

    Raises:
        SchemaUpdateError: if a type changed kind, a struct/union field
            changed type, or an enum member changed value.
        AssertionError: for an unrecognized kind.
    """
    additions = {name: src[name] for name in src.keys() - dst.keys()}
    subtractions = {name: dst[name] for name in dst.keys() - src.keys()}

    shared = src.keys() & dst.keys()
    shared -= {"SCHEMA_VERSION", "TREESPEC_VERSION"}

    for type_name in shared:
        new_kind = src[type_name]["kind"]
        new_fields = src[type_name]["fields"]
        old_kind = dst[type_name]["kind"]
        old_fields = dst[type_name]["fields"]

        _check(
            new_kind == old_kind,
            f"Type {type_name} changed kind from {old_kind} to {new_kind}",
        )
        assert isinstance(new_fields, dict) and isinstance(old_fields, dict)

        gained = {fname: new_fields[fname] for fname in new_fields.keys() - old_fields.keys()}
        lost = {fname: old_fields[fname] for fname in old_fields.keys() - new_fields.keys()}

        for field_name in new_fields.keys() & old_fields.keys():
            new_field = new_fields[field_name]
            old_field = old_fields[field_name]
            if new_kind == "struct":
                _check(
                    new_field["type"] == old_field["type"],
                    f"Type of the field {type_name}.{field_name} changed from {old_field['type']} to {new_field['type']}",
                )
                had_default = "default" in old_field
                has_default = "default" in new_field
                # A newly-added or newly-removed default is recorded like a field change.
                if has_default and not had_default:
                    gained[field_name] = {"default": new_field["default"]}
                if had_default and not has_default:
                    lost[field_name] = {"default": old_field["default"]}
            elif new_kind == "enum":
                _check(
                    new_field == old_field,
                    f"Value of the enum field {type_name}.{field_name} changed from {old_field} to {new_field}",
                )
            elif new_kind == "union":
                _check(
                    new_field["type"] == old_field["type"],
                    f"Type of the field {type_name}.{field_name} changed from {old_field['type']} to {new_field['type']}",
                )
            else:
                raise AssertionError(f"Unknown kind {new_kind}: {type_name}")

        if gained:
            assert type_name not in additions
            additions[type_name] = {"fields": gained}
        if lost:
            assert type_name not in subtractions
            subtractions[type_name] = {"fields": lost}

    return additions, subtractions
def _hash_schema(s):
    """Return the hex SHA-256 digest of ``repr(s)`` encoded as UTF-8.

    Because ``repr`` is hashed, dict insertion order affects the digest.
    """
    hasher = hashlib.sha256()
    hasher.update(repr(s).encode("utf-8"))
    return hasher.hexdigest()
@dataclasses.dataclass
class _Commit:
    """A pending schema update: staged schema, baseline, and the diff between them."""

    # Newly staged schema (built by ``_staged_schema`` in ``update_schema``).
    result: typing.Dict[str, typing.Any]
    # Hex SHA-256 digest of ``result``.
    checksum_result: str
    # Slash-separated path of ``schema.yaml`` derived from the package name.
    path: str
    # Types/fields present in ``result`` but not in ``base``.
    additions: typing.Dict[str, typing.Any]
    # Types/fields present in ``base`` but not in ``result``.
    subtractions: typing.Dict[str, typing.Any]
    # Baseline schema loaded from ``schema.yaml`` (placeholder versions if absent).
    base: typing.Dict[str, typing.Any]
    # Checksum parsed from ``schema.yaml``; ``None`` when the file is absent.
    checksum_base: typing.Optional[str]
def update_schema():
    """Stage the current schema and diff it against the packaged ``schema.yaml``.

    If ``schema.yaml`` is a resource of this package:

    * its ``checksum<<...>>`` marker is extracted as the baseline checksum;
    * the YAML document is loaded as the baseline schema.

    Otherwise the baseline is ``{"SCHEMA_VERSION": None, "TREESPEC_VERSION": None}``
    with no checksum.

    Returns:
        A ``_Commit`` holding:

        * the staged schema and its hash;
        * the slash-separated resource path;
        * the additions/subtractions diff;
        * the baseline schema and checksum.

    Raises:
        SchemaUpdateError: if ``schema.yaml`` exists but has no checksum marker.
    """
    import importlib.resources

    if not importlib.resources.is_resource(__package__, "schema.yaml"):
        checksum_base = None
        dst = {"SCHEMA_VERSION": None, "TREESPEC_VERSION": None}
    else:
        content = importlib.resources.read_text(__package__, "schema.yaml")
        marker = re.search("checksum<<([A-Fa-f0-9]{64})>>", content)
        _check(marker is not None, "checksum not found in schema.yaml")
        assert marker is not None  # narrows Optional for type checkers
        checksum_base = marker.group(1)

        from yaml import load, Loader

        # NOTE(review): yaml.Loader is PyYAML's unrestricted loader; this is only
        # acceptable because schema.yaml ships with the package itself.
        dst = load(content, Loader=Loader)
        assert isinstance(dst, dict)

    staged = _staged_schema()
    added, removed = _diff_schema(dst, staged)
    return _Commit(
        result=staged,
        checksum_result=_hash_schema(staged),
        path=__package__.replace(".", "/") + "/schema.yaml",
        additions=added,
        subtractions=removed,
        base=dst,
        checksum_base=checksum_base,
    )
def check(commit: "_Commit", force_unsafe: bool = False):
    """Decide which schema version bump ``commit`` requires.

    Step 1 detects major (incompatible) changes. Either of the following
    requires ``[base_major + 1, 1]``:

    * a field with no default added to a struct that already existed in
      ``commit.base``;
    * any field removed from a type that still exists in ``commit.result``.

    If ``force_unsafe`` is set, the version already declared in
    ``commit.result["SCHEMA_VERSION"]`` is accepted as-is.

    Otherwise Step 2 detects minor (compatible) changes: if Step 1 found
    nothing but there are any additions or subtractions at all, the result is
    ``[base_major, base_minor + 1]``.

    Args:
        commit: the pending update. It must provide ``result``, ``base``,
            ``additions`` and ``subtractions`` as built by ``update_schema``.
        force_unsafe: skip the minor-bump check and accept the staged version.

    Returns:
        ``(next_version, reason)``.

        * ``next_version`` is the required ``[major, minor]``, or ``None``
          if no bump is needed.
        * ``reason`` accumulates one human-readable line per detected change.
    """
    next_version = None
    reason = ""

    # Step 1: Detect major schema updates.
    for k, v in commit.additions.items():
        if k not in commit.base:
            # Brand-new types cannot break previously serialized payloads.
            continue
        kind = commit.result[k]["kind"]
        for f, d in v["fields"].items():
            # Old payloads cannot supply a new struct field that lacks a default.
            if "default" not in d and kind == "struct":
                reason += (
                    f"Field {k}.{f} is added to schema.py without a default value as an incompatible change "
                    + "which requires major version bump.\n"
                )
                next_version = [commit.base["SCHEMA_VERSION"][0] + 1, 1]

    for k, v in commit.subtractions.items():
        if k not in commit.result:
            # Whole-type removals are handled as minor changes in Step 2.
            continue
        for f in v["fields"]:
            # Accumulate, so every removed field (and earlier messages) is reported.
            reason += (
                f"Field {k}.{f} is removed from schema.py as an incompatible change "
                + "which requires major version bump.\n"
            )
            next_version = [commit.base["SCHEMA_VERSION"][0] + 1, 1]

    if force_unsafe:
        reason += "--force-unsafe is used."
        next_version = commit.result["SCHEMA_VERSION"]
    elif next_version is None and (commit.additions or commit.subtractions):
        # Step 2: Detect minor schema updates. Report additions and removals
        # together so neither kind of change is missing from ``reason``.
        for k, v in commit.additions.items():
            for f in v["fields"]:
                reason += (
                    f"Field {k}.{f} is added to schema.py as a compatible change "
                    + "which still requires minor version bump.\n"
                )
        for k, v in commit.subtractions.items():
            for f in v["fields"]:
                reason += (
                    f"Field {k}.{f} is removed from schema.py as a compatible change "
                    + "which still requires minor version bump.\n"
                )
        next_version = [
            commit.base["SCHEMA_VERSION"][0],
            commit.base["SCHEMA_VERSION"][1] + 1,
        ]

    return next_version, reason