Skip to content

Commit 1d09990

Browse files
authored
feat(atenlib): aten function signature generator (#212)
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): * #235 * #223 * __->__ #212
1 parent ed8dfb6 commit 1d09990

File tree

4 files changed

+338
-2
lines changed

4 files changed

+338
-2
lines changed

.flake8

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,6 @@ exclude =
1818
.eggs,
1919
**test/models/*.py,
2020
**onnx_backend_test_code/*.py,
21+
22+
per-file-ignores =
23+
onnxscript/function_libs/torch_aten/ops/*:E501
Lines changed: 331 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,334 @@
1+
# --------------------------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# Licensed under the MIT License.
4+
# --------------------------------------------------------------------------
5+
16
"""Generates the ATen signatures for the ONNX ATen operator set using torch.ops."""
7+
from __future__ import annotations
8+
9+
import argparse
10+
import ast
11+
import logging
12+
import os
13+
import textwrap
14+
import typing
15+
from typing import Any, Sequence
16+
17+
import torchgen.gen
18+
import torchgen.model
19+
import yaml
20+
21+
import opgen.pygen as cg
22+
23+
24+
def load_native_function_yaml(yaml_path: str) -> tuple[str, set]:
    """Loads native_functions.yaml and collects the tag values used in it.

    Args:
        yaml_path: Path to PyTorch's aten/src/ATen/native/native_functions.yaml.

    Returns:
        A tuple ``(yaml_str, valid_tags)`` where ``yaml_str`` is the raw file
        content and ``valid_tags`` is the set of ``tags`` values found in the
        parsed entries.
    """
    # Read the file once and parse the same text. The original implementation
    # opened and read the file twice (once for the raw text, once for parsing).
    with open(yaml_path, encoding="utf-8") as f:
        yaml_str = f.read()
    all_functions = yaml.safe_load(yaml_str)
    valid_tags = set()
    # Mark all tags as valid, since we don't want to validate them.
    for func in all_functions:
        if "tags" not in func:
            continue
        valid_tags.add(func["tags"])

    return yaml_str, valid_tags
37+
38+
39+
def parse_native_functions_yaml(yaml_path: str) -> tuple[Any, Any]:
    """Parses native_functions.yaml into (native_functions, backend_indices)."""
    raw_yaml, tags = load_native_function_yaml(yaml_path)
    struct = yaml.load(raw_yaml, Loader=torchgen.gen.LineLoader)  # noqa: DUO109
    parsed = torchgen.gen.parse_native_yaml_struct(
        struct,
        tags,
        path=yaml_path,
        skip_native_fns_gen=True,
    )
    return parsed.native_functions, parsed.backend_indices
47+
48+
49+
def create_list_type(arg: torchgen.model.Argument) -> cg.TypeRef:
    """Returns the Python type reference for a list-typed argument."""
    assert isinstance(arg.type, torchgen.model.ListType), f"arg: {arg}"
    element_type = arg_type_to_str(arg.type)
    if type_is_builtin(element_type):
        return cg.TypingRefs.Sequence(cg.BuiltinTypeRef(element_type))
    if element_type == "TensorType":
        return cg.TypingRefs.Sequence(cg.TypeRef("onnxscript.onnx_types", "TensorType"))
    return cg.TypeRef("onnxscript", element_type)

    # TODO(justinchuby): Enable this when generics are better supported
    # if arg.type.size is None:
    #     # INT64[...]
    #     return cg.TypeRef("onnxscript", element_type, cg.EllipsisTypeRef())
    # # INT64[3]
    # return cg.TypeRef("onnxscript", element_type, *[cg.TypeRef(None, f"{arg.type.size}")])
64+
65+
66+
def arg_type_to_str(arg_type: torchgen.model.Type) -> str:
    """Maps a torchgen type to its onnxscript/Python type name string."""
    # Ordered (base type, target name) pairs; the first match wins, matching
    # the priority of the original if/elif chain.
    conversions = (
        (torchgen.model.BaseTy.Tensor, "TensorType"),
        (torchgen.model.BaseTy.SymInt, "INT64"),
        (torchgen.model.BaseTy.Scalar, "float"),
        (torchgen.model.BaseTy.float, "float"),
        (torchgen.model.BaseTy.int, "int"),
        (torchgen.model.BaseTy.bool, "bool"),
        (torchgen.model.BaseTy.str, "str"),
        (torchgen.model.BaseTy.ScalarType, "int"),
    )
    for base_ty, type_name in conversions:
        if arg_type.is_base_ty_like(base_ty):
            return type_name
    # Anything unhandled is a string option.
    return "str"
86+
87+
88+
def type_is_builtin(arg_type: str) -> bool:
    """Returns whether the given type is a python builtin type (that we care about)."""
    builtin_names = frozenset({"float", "int", "bool", "str"})
    return arg_type in builtin_names
91+
92+
93+
def get_argument_type(arg: torchgen.model.Argument) -> cg.TypeRef:
    """Returns the Python type for the given argument."""
    if isinstance(arg.type, torchgen.model.ListType):
        base_node = create_list_type(arg)
    else:
        type_name = arg_type_to_str(arg.type)
        if type_is_builtin(type_name):
            base_node = cg.BuiltinTypeRef(type_name)
        elif type_name == "TensorType":
            base_node = cg.TypeRef("onnxscript.onnx_types", "TensorType")
        else:
            base_node = cg.TypeRef("onnxscript", type_name)

    # Nullable arguments, and arguments whose declared default parses to None,
    # are wrapped in Optional.
    if arg.type.is_nullable():
        return cg.TypingRefs.Optional(base_node)
    if arg.default is not None and parse_default_value(arg) is None:
        return cg.TypingRefs.Optional(base_node)
    return base_node
111+
112+
113+
def should_generate_signature(func: torchgen.model.NativeFunction) -> bool:
    """Returns whether the signature for the given function should be generated."""
    op = func.func.name
    if op.name.base.startswith("_"):
        # Skip private functions.
        return False
    if op.name.inplace:
        # Skip inplace operators.
        return False
    # Ignore overloads for now, except "Tensor": some ops only have overloaded
    # versions, like aten::add.Tensor, and we want to generate the aten::add op.
    if op.overload_name and op.overload_name != "Tensor":
        return False
    return True
125+
126+
127+
def get_op_name(func: torchgen.model.NativeFunction) -> str:
    """Builds the namespace-prefixed Python-safe name for the given op."""
    overload = func.func.name.overload_name
    base = func.func.name.name.base
    # Do not include the overload name if it is "Tensor", since ops like
    # aten::add.Tensor is what we want for aten::add.
    name = f"{base}_{overload}" if overload and overload != "Tensor" else f"{base}"
    # Prefix with namespace to avoid name conflicts with other operators and arguments.
    return f"{func.namespace}_{name}"
137+
138+
139+
def parse_default_value(arg: torchgen.model.Argument) -> Any:
    """Parses an argument's declared default into a Python value.

    List defaults become tuples (an empty list becomes None), "Mean" is
    special-cased to 1, integer defaults for list-typed arguments are
    expanded to tuples, and anything that is not a Python literal is
    treated as a lowercased string option.

    Args:
        arg: The torchgen argument; ``arg.default`` must not be None.

    Returns:
        The parsed default value, or None for an empty-list default.
    """
    default = arg.default
    assert default is not None, f"arg: {arg}"
    if default.startswith("[") and default.endswith("]"):
        # Convert list to tuple
        default_val = ast.literal_eval(default)
        assert isinstance(default_val, list)
        if not default_val:
            # Empty list is represented as None.
            return None
        return tuple(default_val)
    # Special case for reduction=Mean
    if default == "Mean":
        return 1

    try:
        value = ast.literal_eval(default)
        if isinstance(value, int):
            # Expand the value to a tuple if the type is a list.
            if isinstance(arg.type, torchgen.model.ListType):
                if arg.type.size is not None:
                    return (value,) * arg.type.size
                return (value,)
        return value
    except (ValueError, SyntaxError):
        # Treat it as a string. ast.literal_eval raises ValueError for
        # malformed nodes (e.g. bare names like "contiguous_format") and
        # SyntaxError for unparsable text (e.g. C++-style defaults containing
        # "::"); the original only caught ValueError and would crash on the
        # latter.
        return default.lower()
166+
167+
168+
def create_return_type(returns: Sequence[torchgen.model.Return]) -> cg.TypeRef:
    """Returns the Python type for the return value of the given function."""
    if not returns:
        return cg.TypingRefs.Any()

    def to_node(ret):
        # Map one return value to its type reference.
        type_str = arg_type_to_str(ret.type)
        if type_is_builtin(type_str):
            # Python builtin type
            node = cg.BuiltinTypeRef(type_str)
        elif type_str == "TensorType":
            node = cg.TypeRef("onnxscript.onnx_types", "TensorType")
        else:
            node = cg.TypeRef("onnxscript", type_str)
        if ret.type.is_nullable():
            node = cg.TypingRefs.Optional(node)
        return node

    return_nodes = [to_node(ret) for ret in returns]
    if len(return_nodes) == 1:
        return return_nodes[0]
    return cg.BuiltinTypeRef("tuple", *return_nodes)
189+
190+
191+
def format_arg_name(arg: torchgen.model.Argument) -> str:
    """Returns the python compatible name of the given argument."""
    # "from" is a Python keyword and cannot be used as a parameter name.
    name = arg.name
    return f"{name}_" if name == "from" else name  # type: ignore[no-any-return]
196+
197+
198+
def create_signature(func: torchgen.model.NativeFunction) -> cg.FunctionDef:
    """Creates the signature for the given function."""
    op_name = get_op_name(func)

    def to_py_arg(argument, is_kwarg):
        # Build one cg.Arg; only arguments that declare a default get a
        # default value node, and is_kwarg is only passed for keyword-only
        # arguments (matching the original call shapes exactly).
        default = (
            cg.Constant(parse_default_value(argument))
            if argument.default is not None
            else None
        )
        extra = {"is_kwarg": True} if is_kwarg else {}
        return cg.Arg(
            format_arg_name(argument),
            get_argument_type(argument),
            default_value=default,
            **extra,
        )

    # Unwrap SelfArgument to its underlying Argument for positional args.
    positional = [
        arg.argument if isinstance(arg, torchgen.model.SelfArgument) else arg
        for arg in func.func.arguments.positional
    ]
    # TensorOptionsArguments are dropped from the keyword-only set.
    keyword_only = [
        arg
        for arg in func.func.arguments.kwarg_only
        if not isinstance(arg, torchgen.model.TensorOptionsArguments)
    ]

    py_args = [to_py_arg(arg, False) for arg in positional]
    # Arguments after this point are keyword-only.
    py_args.extend(to_py_arg(kwarg, True) for kwarg in keyword_only)

    return cg.FunctionDef(
        op_name,
        *py_args,
        return_type=create_return_type(func.func.returns),
        body=[
            cg.ThunkStmt(f"# {func.func}"),
            cg.Raise(cg.Call(cg.Name("NotImplementedError"))),  # type: ignore[list-item]
        ],
    )
244+
245+
246+
def create_onnx_function_module(
    functions: Sequence[torchgen.model.NativeFunction],
) -> cg.Module:
    """Creates the onnx function module."""
    # Every generated module starts with a `from __future__ import annotations`.
    members = [cg.ImportFrom("__future__", cg.Alias("annotations"))]
    for func in functions:
        if should_generate_signature(func):
            members.append(create_signature(func))
    return cg.Module(*members)
254+
255+
256+
def copyright_header() -> str:
    """Creates the copyright header."""
    dashline = f"# {'-' * 74}"
    header_lines = [
        dashline,
        "# Copyright (c) Microsoft Corporation. All rights reserved.",
        "# Licensed under the MIT License.",
        dashline,
        "# mypy: disable-error-code=misc",
        "# mypy: disable-error-code=type-arg",
        "# mypy: disable-error-code=valid-type",
        "# mypy: disable-error-code=assignment",
    ]
    # Trailing newline so the generated code starts on a fresh line.
    return "\n".join(header_lines) + "\n"
271+
272+
273+
def main(args: argparse.Namespace) -> None:
    """Generates one Python module of op signatures per torch python_module.

    Parses native_functions.yaml, groups the eligible functions by their
    python_module (defaulting to "core"), and writes a ``<module>.py`` file
    containing a stub signature for each op into ``args.outdir``.
    """
    native_functions, _ = parse_native_functions_yaml(args.yaml)

    # module name -> op name -> function. Only the first function seen for a
    # given op name is kept.
    functions: dict[str, dict[str, torchgen.model.NativeFunction]] = {}
    for func in native_functions:
        if not should_generate_signature(func):
            continue

        module_name = typing.cast(str, func.python_module)
        if not module_name:
            # Functions without a python_module are grouped under "core".
            module_name = "core"
        if module_name not in functions:
            functions[module_name] = {}
        op_name = get_op_name(func)
        if op_name in functions[module_name]:
            # get_op_name drops the "Tensor" overload suffix, so distinct
            # overloads can collide here; warn and keep the first one.
            logging.warning(
                "Duplicated function: %s, overload: %s", op_name, func.func.name.overload_name
            )
            continue
        functions[module_name][op_name] = func

    os.makedirs(args.outdir, exist_ok=True)

    for module_name, module_functions in functions.items():
        # Sort ops by name so output is deterministic and diff-friendly.
        sorted_functions = sorted(module_functions.items(), key=lambda x: x[0])
        py_module = create_onnx_function_module([func for _, func in sorted_functions])
        py_module.accept(cg.ImportAdjuster())
        py_module.accept(cg.DocCommentBuilder())
        output_path = os.path.join(args.outdir, f"{module_name}.py")

        print(f"Generating {output_path}")
        with open(output_path, "w", encoding="utf-8") as f:
            f.write(copyright_header())
            # Add docstring
            f.write(
                textwrap.dedent(
                    f'''\
                    """torch.ops.aten operators under the `{module_name}` module.

                    - No inplace operators.
                    - All functions should not have the script() decorator. This is because
                      we want to delay the compilation of the function.
                    """
                    '''
                )
            )
            py_module.accept(cg.PythonWriter(f))
    print("Done.")
2320

3321

4-
def main():
5-
pass
322+
if __name__ == "__main__":
    # Command-line entry point: --yaml is the input spec, --outdir the target.
    arg_parser = argparse.ArgumentParser()
    for flag, help_text in (
        ("--yaml", "Path to PyTorch aten/src/ATen/native/native_functions.yaml"),
        ("--outdir", "Output directory for generated modules"),
    ):
        arg_parser.add_argument(flag, type=str, help=help_text)
    main(arg_parser.parse_args())

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ module = [
7171
"onnxruntime.*",
7272
"autopep8.*",
7373
"parameterized.*",
74+
"torchgen.*",
7475
]
7576
ignore_missing_imports = true
7677

requirements-dev.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ sphinx<=5.3.0
1111
sphinx-gallery
1212
pydata_sphinx_theme
1313

14+
# ATen lib
15+
types-PyYAML
16+
1417
# Testing
1518
pytest!=7.1.0
1619
pytest-cov

0 commit comments

Comments
 (0)