-
Notifications
You must be signed in to change notification settings - Fork 66
feat(atenlib): aten function signature generator #212
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 17 commits
Commits
Show all changes
41 commits
Select commit
Hold shift + click to select a range
2c8ee88
feat(atenlib): establish the aten-lib directory
justinchuby 40af0b4
Update base for Update on "feat(atenlib): establish the aten-lib dire…
justinchuby ee554d7
Update on "feat(atenlib): establish the aten-lib directory"
justinchuby fa8ce18
feat(atenlib): Create sample functions and tests
justinchuby 6d426b5
Update base for Update on "feat(atenlib): Create sample functions and…
justinchuby bb9fbee
Update on "feat(atenlib): Create sample functions and tests with OpInfo"
justinchuby 0aed28e
Update base for Update on "feat(atenlib): create tests with OpInfo"
justinchuby 91aa327
Update on "feat(atenlib): create tests with OpInfo"
justinchuby 00a081b
Update base for Update on "feat(atenlib): create tests with OpInfo"
justinchuby fce8072
Update on "feat(atenlib): create tests with OpInfo"
justinchuby af56008
Update base for Update on "feat(atenlib): create tests with OpInfo"
justinchuby 8c0b370
Update on "feat(atenlib): create tests with OpInfo"
justinchuby bfeaefc
Update base for Update on "feat(atenlib): create tests with OpInfo"
justinchuby 8dc3d15
Update on "feat(atenlib): create tests with OpInfo"
justinchuby 46aa719
Update base for Update on "feat(atenlib): create tests with OpInfo"
justinchuby d87f309
Update on "feat(atenlib): create tests with OpInfo"
justinchuby afa50e3
feat: aten function signature generator
justinchuby 1b2674d
Update base for Update on "feat(atenlib): aten function signature gen…
justinchuby 1861734
Update on "feat(atenlib): aten function signature generator"
justinchuby 64810de
Update base for Update on "feat(atenlib): aten function signature gen…
justinchuby 421e9e4
Update on "feat(atenlib): aten function signature generator"
justinchuby c566e24
Update base for Update on "feat(atenlib): aten function signature gen…
justinchuby 15ecd59
Update on "feat(atenlib): aten function signature generator"
justinchuby 2ab2fc3
Update base for Update on "feat(atenlib): aten function signature gen…
justinchuby 91fb612
Update on "feat(atenlib): aten function signature generator"
justinchuby d49a69d
Update base for Update on "feat(atenlib): aten function signature gen…
justinchuby c7940b2
Update on "feat(atenlib): aten function signature generator"
justinchuby 4c2c2c6
Update base for Update on "feat(atenlib): aten function signature gen…
justinchuby ed6baf7
Update on "feat(atenlib): aten function signature generator"
justinchuby e5e4910
Update base for Update on "feat(atenlib): aten function signature gen…
justinchuby 1353d34
Update on "feat(atenlib): aten function signature generator"
justinchuby 5cb196e
Update base for Update on "feat(atenlib): aten function signature gen…
justinchuby b373478
Update on "feat(atenlib): aten function signature generator"
justinchuby 75b2c9e
Update base for Update on "feat(atenlib): aten function signature gen…
justinchuby fe5ce17
Update on "feat(atenlib): aten function signature generator"
justinchuby 9fc6c1a
Update base for Update on "feat(atenlib): aten function signature gen…
justinchuby ff48537
Update on "feat(atenlib): aten function signature generator"
justinchuby 3f85493
Update base for Update on "feat(atenlib): aten function signature gen…
justinchuby f1c2c3e
Update on "feat(atenlib): aten function signature generator"
justinchuby c33ad93
Update base for Update on "feat(atenlib): aten function signature gen…
justinchuby 6c160e5
Update on "feat(atenlib): aten function signature generator"
justinchuby File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
171 changes: 171 additions & 0 deletions
171
onnxscript/fuction_libs/tools/torch_aten/generate_aten_signatures.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,171 @@ | ||
"""Generates the ATen signatures for the ONNX ATen operator set using torch.ops.""" | ||
from __future__ import annotations | ||
import argparse | ||
import sys | ||
from typing import Any | ||
|
||
import pygen as cg | ||
import torchgen.gen | ||
|
||
import torchgen.model | ||
|
||
import yaml | ||
|
||
|
||
|
||
# BaseTy = Enum( | ||
# "BaseTy", | ||
# ( | ||
# "Generator", | ||
# "ScalarType", | ||
# "Tensor", | ||
# "int", | ||
# "Dimname", | ||
# "DimVector", | ||
# "float", | ||
# "str", | ||
# "bool", | ||
# "Layout", | ||
# "Device", | ||
# "Scalar", | ||
# "MemoryFormat", | ||
# "QScheme", | ||
# "Storage", | ||
# "Stream", | ||
# "SymInt", | ||
# "ConstQuantizerPtr", | ||
# ), | ||
# ) | ||
|
||
|
||
def load_native_function_yaml(yaml_path: str): | ||
with open(yaml_path, encoding="utf-8") as f: | ||
yaml_str = f.read() | ||
with open(yaml_path, encoding="utf-8") as f: | ||
all_functions = yaml.safe_load(f) | ||
valid_tags = set() | ||
for func in all_functions: | ||
if "tags" not in func: | ||
continue | ||
valid_tags.add(func["tags"]) | ||
|
||
return yaml_str, valid_tags | ||
|
||
|
||
def parse_native_functions_yaml(yaml_path: str) -> tuple[Any, Any]: | ||
yaml_str, valid_tags = load_native_function_yaml(yaml_path) | ||
yaml_struct = yaml.load(yaml_str, Loader=torchgen.gen.LineLoader) | ||
parsed = torchgen.gen.parse_native_yaml_struct( | ||
yaml_struct, valid_tags, path=yaml_path, skip_native_fns_gen=True | ||
) | ||
return parsed.native_functions, parsed.backend_indices | ||
|
||
|
||
def get_argument_type(arg: torchgen.model.Argument) -> cg.TypeRef: | ||
# TODO: Handel scalar type | ||
optional = arg.type.is_nullable() | ||
if arg.type.is_base_ty_like(torchgen.model.BaseTy.Tensor): | ||
inner_type = cg.TypeRef(None, "Tensor") | ||
elif arg.type.is_base_ty_like(torchgen.model.BaseTy.SymInt): | ||
# TODO(justinchuby): Make sure this is a scalar | ||
inner_type = cg.TypeRef(None, "INT64") | ||
elif arg.type.is_base_ty_like(torchgen.model.BaseTy.float): | ||
inner_type = cg.TypeRef(None, "float") | ||
elif arg.type.is_base_ty_like(torchgen.model.BaseTy.int): | ||
inner_type = cg.TypeRef(None, "int") | ||
elif arg.type.is_base_ty_like(torchgen.model.BaseTy.bool): | ||
inner_type = cg.TypeRef(None, "bool") | ||
else: | ||
inner_type = cg.TypeRef(None, "Any") | ||
|
||
if optional: | ||
return cg.TypeRef(None, "Optional", inner_type) | ||
return inner_type | ||
|
||
|
||
def should_generate_signature(func: torchgen.model.NativeFunction) -> bool: | ||
"""Returns whether the signature for the given function should be generated.""" | ||
if func.func.name.name.base.startswith("_"): | ||
return False | ||
if func.func.name.overload_name: | ||
# Ignore overloads for now. | ||
return False | ||
return True | ||
|
||
|
||
def get_op_name(func: torchgen.model.NativeFunction) -> str: | ||
if func.func.name.overload_name: | ||
return f"{func.func.name.name.base}_{func.func.name.overload_name}" | ||
|
||
return func.func.name.name.base | ||
|
||
|
||
|
||
def create_signature(func: torchgen.model.NativeFunction) -> Any: | ||
"""Creates the signature for the given function.""" | ||
print(func.namespace) | ||
print(func.python_module) | ||
print(func.func.name.overload_name) | ||
|
||
op_name = get_op_name(func) | ||
args = [ | ||
arg.argument if isinstance(arg, torchgen.model.SelfArgument) else arg | ||
for arg in func.func.arguments.positional | ||
] | ||
kwargs = [ | ||
arg | ||
for arg in func.func.arguments.kwarg_only | ||
if not isinstance(arg, torchgen.model.TensorOptionsArguments) | ||
] | ||
|
||
py_args = [ | ||
cg.Arg( | ||
arg.name, | ||
get_argument_type(arg), | ||
default_value=cg.ThunkExpr(arg.default) if arg.default is not None else None, | ||
) | ||
for arg in args | ||
] | ||
if kwargs: | ||
py_args += [ | ||
# Arguments after this point are keyword-only. | ||
cg.Arg( | ||
"*", | ||
) | ||
] + [ | ||
cg.Arg( | ||
kwarg.name, | ||
get_argument_type(kwarg), | ||
default_value=cg.ThunkExpr(kwarg.default or "None"), | ||
is_kwarg=True, | ||
) | ||
for kwarg in kwargs | ||
] | ||
|
||
return cg.FunctionDef( | ||
op_name, | ||
*py_args, | ||
return_type=None, # TODO: Add return type | ||
body=[ | ||
cg.Raise( | ||
cg.ThunkExpr("NotImplementedError"), | ||
) | ||
], | ||
) | ||
|
||
|
||
def main(args: argparse.Namespace) -> None: | ||
native_functions, _ = parse_native_functions_yaml(args.native_functions_yaml) | ||
|
||
for func in native_functions: | ||
if not should_generate_signature(func): | ||
continue | ||
|
||
py_tree = create_signature(func) | ||
py_tree.accept(cg.PythonWriter(sys.stdout)) | ||
print() | ||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument( | ||
"--native-functions-yaml", | ||
type=str, | ||
default="/home/justinchu/dev/pytorch/aten/src/ATen/native/native_functions.yaml", | ||
) | ||
main(parser.parse_args()) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.