Skip to content

feat(atenlib): aten function signature generator #212

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 41 commits into from
Dec 7, 2022
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
2c8ee88
feat(atenlib): establish the aten-lib directory
justinchuby Nov 22, 2022
40af0b4
Update base for Update on "feat(atenlib): establish the aten-lib dire…
justinchuby Nov 22, 2022
ee554d7
Update on "feat(atenlib): establish the aten-lib directory"
justinchuby Nov 22, 2022
fa8ce18
feat(atenlib): Create sample functions and tests
justinchuby Nov 23, 2022
6d426b5
Update base for Update on "feat(atenlib): Create sample functions and…
justinchuby Nov 23, 2022
bb9fbee
Update on "feat(atenlib): Create sample functions and tests with OpInfo"
justinchuby Nov 23, 2022
0aed28e
Update base for Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 23, 2022
91aa327
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 23, 2022
00a081b
Update base for Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 23, 2022
fce8072
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 23, 2022
af56008
Update base for Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 23, 2022
8c0b370
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 23, 2022
bfeaefc
Update base for Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 23, 2022
8dc3d15
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 23, 2022
46aa719
Update base for Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 23, 2022
d87f309
Update on "feat(atenlib): create tests with OpInfo"
justinchuby Nov 23, 2022
afa50e3
feat: aten function signature generator
justinchuby Nov 29, 2022
1b2674d
Update base for Update on "feat(atenlib): aten function signature gen…
justinchuby Nov 29, 2022
1861734
Update on "feat(atenlib): aten function signature generator"
justinchuby Nov 29, 2022
64810de
Update base for Update on "feat(atenlib): aten function signature gen…
justinchuby Nov 29, 2022
421e9e4
Update on "feat(atenlib): aten function signature generator"
justinchuby Nov 29, 2022
c566e24
Update base for Update on "feat(atenlib): aten function signature gen…
justinchuby Nov 29, 2022
15ecd59
Update on "feat(atenlib): aten function signature generator"
justinchuby Nov 29, 2022
2ab2fc3
Update base for Update on "feat(atenlib): aten function signature gen…
justinchuby Nov 29, 2022
91fb612
Update on "feat(atenlib): aten function signature generator"
justinchuby Nov 29, 2022
d49a69d
Update base for Update on "feat(atenlib): aten function signature gen…
justinchuby Nov 30, 2022
c7940b2
Update on "feat(atenlib): aten function signature generator"
justinchuby Nov 30, 2022
4c2c2c6
Update base for Update on "feat(atenlib): aten function signature gen…
justinchuby Nov 30, 2022
ed6baf7
Update on "feat(atenlib): aten function signature generator"
justinchuby Nov 30, 2022
e5e4910
Update base for Update on "feat(atenlib): aten function signature gen…
justinchuby Nov 30, 2022
1353d34
Update on "feat(atenlib): aten function signature generator"
justinchuby Nov 30, 2022
5cb196e
Update base for Update on "feat(atenlib): aten function signature gen…
justinchuby Dec 5, 2022
b373478
Update on "feat(atenlib): aten function signature generator"
justinchuby Dec 5, 2022
75b2c9e
Update base for Update on "feat(atenlib): aten function signature gen…
justinchuby Dec 6, 2022
fe5ce17
Update on "feat(atenlib): aten function signature generator"
justinchuby Dec 6, 2022
9fc6c1a
Update base for Update on "feat(atenlib): aten function signature gen…
justinchuby Dec 6, 2022
ff48537
Update on "feat(atenlib): aten function signature generator"
justinchuby Dec 6, 2022
3f85493
Update base for Update on "feat(atenlib): aten function signature gen…
justinchuby Dec 6, 2022
f1c2c3e
Update on "feat(atenlib): aten function signature generator"
justinchuby Dec 6, 2022
c33ad93
Update base for Update on "feat(atenlib): aten function signature gen…
justinchuby Dec 7, 2022
6c160e5
Update on "feat(atenlib): aten function signature generator"
justinchuby Dec 7, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 171 additions & 0 deletions onnxscript/fuction_libs/tools/torch_aten/generate_aten_signatures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
"""Generates the ATen signatures for the ONNX ATen operator set using torch.ops."""
from __future__ import annotations
import argparse
import sys
from typing import Any

import pygen as cg
import torchgen.gen
import torchgen.model
import yaml


# BaseTy = Enum(
# "BaseTy",
# (
# "Generator",
# "ScalarType",
# "Tensor",
# "int",
# "Dimname",
# "DimVector",
# "float",
# "str",
# "bool",
# "Layout",
# "Device",
# "Scalar",
# "MemoryFormat",
# "QScheme",
# "Storage",
# "Stream",
# "SymInt",
# "ConstQuantizerPtr",
# ),
# )


def load_native_function_yaml(yaml_path: str):
with open(yaml_path, encoding="utf-8") as f:
yaml_str = f.read()
with open(yaml_path, encoding="utf-8") as f:
all_functions = yaml.safe_load(f)
valid_tags = set()
for func in all_functions:
if "tags" not in func:
continue
valid_tags.add(func["tags"])

return yaml_str, valid_tags


def parse_native_functions_yaml(yaml_path: str) -> tuple[Any, Any]:
yaml_str, valid_tags = load_native_function_yaml(yaml_path)
yaml_struct = yaml.load(yaml_str, Loader=torchgen.gen.LineLoader)
parsed = torchgen.gen.parse_native_yaml_struct(
yaml_struct, valid_tags, path=yaml_path, skip_native_fns_gen=True
)
return parsed.native_functions, parsed.backend_indices


def get_argument_type(arg: torchgen.model.Argument) -> cg.TypeRef:
# TODO: Handel scalar type
optional = arg.type.is_nullable()
if arg.type.is_base_ty_like(torchgen.model.BaseTy.Tensor):
inner_type = cg.TypeRef(None, "Tensor")
elif arg.type.is_base_ty_like(torchgen.model.BaseTy.SymInt):
# TODO(justinchuby): Make sure this is a scalar
inner_type = cg.TypeRef(None, "INT64")
elif arg.type.is_base_ty_like(torchgen.model.BaseTy.float):
inner_type = cg.TypeRef(None, "float")
elif arg.type.is_base_ty_like(torchgen.model.BaseTy.int):
inner_type = cg.TypeRef(None, "int")
elif arg.type.is_base_ty_like(torchgen.model.BaseTy.bool):
inner_type = cg.TypeRef(None, "bool")
else:
inner_type = cg.TypeRef(None, "Any")

if optional:
return cg.TypeRef(None, "Optional", inner_type)
return inner_type


def should_generate_signature(func: torchgen.model.NativeFunction) -> bool:
"""Returns whether the signature for the given function should be generated."""
if func.func.name.name.base.startswith("_"):
return False
if func.func.name.overload_name:
# Ignore overloads for now.
return False
return True


def get_op_name(func: torchgen.model.NativeFunction) -> str:
if func.func.name.overload_name:
return f"{func.func.name.name.base}_{func.func.name.overload_name}"

return func.func.name.name.base


def create_signature(func: torchgen.model.NativeFunction) -> Any:
"""Creates the signature for the given function."""
print(func.namespace)
print(func.python_module)
print(func.func.name.overload_name)

op_name = get_op_name(func)
args = [
arg.argument if isinstance(arg, torchgen.model.SelfArgument) else arg
for arg in func.func.arguments.positional
]
kwargs = [
arg
for arg in func.func.arguments.kwarg_only
if not isinstance(arg, torchgen.model.TensorOptionsArguments)
]

py_args = [
cg.Arg(
arg.name,
get_argument_type(arg),
default_value=cg.ThunkExpr(arg.default) if arg.default is not None else None,
)
for arg in args
]
if kwargs:
py_args += [
# Arguments after this point are keyword-only.
cg.Arg(
"*",
)
] + [
cg.Arg(
kwarg.name,
get_argument_type(kwarg),
default_value=cg.ThunkExpr(kwarg.default or "None"),
is_kwarg=True,
)
for kwarg in kwargs
]

return cg.FunctionDef(
op_name,
*py_args,
return_type=None, # TODO: Add return type
body=[
cg.Raise(
cg.ThunkExpr("NotImplementedError"),
)
],
)


def main(args: argparse.Namespace) -> None:
native_functions, _ = parse_native_functions_yaml(args.native_functions_yaml)

for func in native_functions:
if not should_generate_signature(func):
continue

py_tree = create_signature(func)
py_tree.accept(cg.PythonWriter(sys.stdout))
print()

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--native-functions-yaml",
type=str,
default="/home/justinchu/dev/pytorch/aten/src/ATen/native/native_functions.yaml",
)
main(parser.parse_args())
Loading