|
| 1 | +# -------------------------------------------------------------------------- |
| 2 | +# Copyright (c) Microsoft Corporation. All rights reserved. |
| 3 | +# Licensed under the MIT License. |
| 4 | +# -------------------------------------------------------------------------- |
| 5 | + |
1 | 6 | """Generates the ATen signatures for the ONNX ATen operator set using torch.ops.""" |
| 7 | +from __future__ import annotations |
| 8 | + |
| 9 | +import argparse |
| 10 | +import ast |
| 11 | +import logging |
| 12 | +import os |
| 13 | +import textwrap |
| 14 | +import typing |
| 15 | +from typing import Any, Sequence |
| 16 | + |
| 17 | +import torchgen.gen |
| 18 | +import torchgen.model |
| 19 | +import yaml |
| 20 | + |
| 21 | +import opgen.pygen as cg |
| 22 | + |
| 23 | + |
def load_native_function_yaml(yaml_path: str) -> tuple[str, set]:
    """Loads native_functions.yaml and collects every tag value used in it.

    Args:
        yaml_path: Path to PyTorch's native_functions.yaml.

    Returns:
        A tuple of (raw yaml text, set of tag values found in the file).
    """
    # Read the file once and reuse the text for parsing, instead of opening
    # and reading the same file a second time.
    with open(yaml_path, encoding="utf-8") as f:
        yaml_str = f.read()
    all_functions = yaml.safe_load(yaml_str)

    valid_tags = set()
    # Mark all tags as valid, since we don't want to validate them.
    for func in all_functions:
        if "tags" not in func:
            continue
        valid_tags.add(func["tags"])

    return yaml_str, valid_tags
| 37 | + |
| 38 | + |
def parse_native_functions_yaml(yaml_path: str) -> tuple[Any, Any]:
    """Parses the native_functions.yaml file."""
    yaml_text, tags = load_native_function_yaml(yaml_path)
    # LineLoader records line numbers so torchgen can point at the source
    # location when a declaration fails to parse.
    yaml_struct = yaml.load(yaml_text, Loader=torchgen.gen.LineLoader)  # noqa: DUO109
    result = torchgen.gen.parse_native_yaml_struct(
        yaml_struct, tags, path=yaml_path, skip_native_fns_gen=True
    )
    return result.native_functions, result.backend_indices
| 47 | + |
| 48 | + |
def create_list_type(arg: torchgen.model.Argument) -> cg.TypeRef:
    """Returns the type reference for a list-typed (sequence) argument."""
    assert isinstance(arg.type, torchgen.model.ListType), f"arg: {arg}"
    element_type = arg_type_to_str(arg.type)
    # Builtin elements and tensors become Sequence[...]; everything else maps
    # to an onnxscript tensor type name directly.
    if type_is_builtin(element_type):
        return cg.TypingRefs.Sequence(cg.BuiltinTypeRef(element_type))
    if element_type == "TensorType":
        return cg.TypingRefs.Sequence(cg.TypeRef("onnxscript.onnx_types", "TensorType"))
    return cg.TypeRef("onnxscript", element_type)

    # TODO(justinchuby): Enable this when generics are better supported
    # if arg.type.size is None:
    #     # INT64[...]
    #     return cg.TypeRef("onnxscript", arg_type, cg.EllipsisTypeRef())
    # # INT64[3]
    # return cg.TypeRef("onnxscript", arg_type, *[cg.TypeRef(None, f"{arg.type.size}")])
| 64 | + |
| 65 | + |
def arg_type_to_str(arg_type: torchgen.model.Type) -> str:
    """Maps a torchgen type to its python/onnxscript type name."""
    # Checked in order; the first match wins.
    base_ty_to_name = (
        (torchgen.model.BaseTy.Tensor, "TensorType"),
        (torchgen.model.BaseTy.SymInt, "INT64"),
        (torchgen.model.BaseTy.Scalar, "float"),
        (torchgen.model.BaseTy.float, "float"),
        (torchgen.model.BaseTy.int, "int"),
        (torchgen.model.BaseTy.bool, "bool"),
        (torchgen.model.BaseTy.str, "str"),
        (torchgen.model.BaseTy.ScalarType, "int"),
    )
    for base_ty, type_name in base_ty_to_name:
        if arg_type.is_base_ty_like(base_ty):
            return type_name
    # Anything unhandled is a string option.
    return "str"
| 86 | + |
| 87 | + |
def type_is_builtin(arg_type: str) -> bool:
    """Returns whether the given type is a python builtin type (that we care about)."""
    builtin_names = frozenset({"float", "int", "bool", "str"})
    return arg_type in builtin_names
| 91 | + |
| 92 | + |
def get_argument_type(arg: torchgen.model.Argument) -> cg.TypeRef:
    """Returns the Python type for the given argument."""
    if isinstance(arg.type, torchgen.model.ListType):
        base_node = create_list_type(arg)
    else:
        type_name = arg_type_to_str(arg.type)
        if type_is_builtin(type_name):
            base_node = cg.BuiltinTypeRef(type_name)
        elif type_name == "TensorType":
            base_node = cg.TypeRef("onnxscript.onnx_types", "TensorType")
        else:
            base_node = cg.TypeRef("onnxscript", type_name)

    # Wrap in Optional when the schema marks the argument nullable, or when a
    # declared default parses to None (e.g. an empty-list default). `or`
    # short-circuits, so parse_default_value only runs when needed.
    is_optional = arg.type.is_nullable() or (
        arg.default is not None and parse_default_value(arg) is None
    )
    return cg.TypingRefs.Optional(base_node) if is_optional else base_node
| 111 | + |
| 112 | + |
def should_generate_signature(func: torchgen.model.NativeFunction) -> bool:
    """Returns whether the signature for the given function should be generated."""
    func_name = func.func.name
    # Skip private and inplace variants.
    if func_name.name.base.startswith("_") or func_name.name.inplace:
        return False
    # Ignore overloads for now, except ".Tensor": some ops only exist as
    # overloaded versions, like aten::add.Tensor, and we want aten::add.
    if func_name.overload_name and func_name.overload_name != "Tensor":
        return False
    return True
| 125 | + |
| 126 | + |
def get_op_name(func: torchgen.model.NativeFunction) -> str:
    """Builds the namespaced python identifier for the given operator."""
    base = func.func.name.name.base
    overload = func.func.name.overload_name
    # Do not include the overload name if it is "Tensor", since ops like
    # aten::add.Tensor is what we want for aten::add.
    name = f"{base}_{overload}" if overload and overload != "Tensor" else f"{base}"

    # Prefix with namespace to avoid name conflicts with other operators and arguments.
    return f"{func.namespace}_{name}"
| 137 | + |
| 138 | + |
def parse_default_value(arg: torchgen.model.Argument) -> Any:
    """Parses the native-schema default of *arg* into a python value.

    Returns None for an empty-list default (callers treat None as "optional"),
    a tuple for list defaults, 1 for the special "Mean" reduction default, the
    parsed literal for python literals (scalar int defaults are expanded to a
    tuple when the argument type is a list), and the lowercased raw text for
    anything that is not a python literal (e.g. string options).
    """
    default = arg.default
    assert default is not None, f"arg: {arg}"
    if default.startswith("[") and default.endswith("]"):
        # Convert list to tuple
        default_val = ast.literal_eval(default)
        assert isinstance(default_val, list)
        if not default_val:
            # Empty list is represented as None.
            return None
        return tuple(default_val)
    # Special case for reduction=Mean
    if default == "Mean":
        return 1

    try:
        # Keep only the raising call in the try body.
        value = ast.literal_eval(default)
    except (ValueError, SyntaxError):
        # Not a python literal (e.g. "contiguous_format"): treat it as a
        # string option. ast.literal_eval raises SyntaxError, not ValueError,
        # for some malformed inputs, so both must be caught.
        return default.lower()
    if isinstance(value, int) and isinstance(arg.type, torchgen.model.ListType):
        # Expand a scalar int default to a tuple when the type is a list.
        if arg.type.size is not None:
            return (value,) * arg.type.size
        return (value,)
    # Non-int literals (floats, None, ...) are returned as parsed.
    return value
| 166 | + |
| 167 | + |
def create_return_type(returns: Sequence[torchgen.model.Return]) -> cg.TypeRef:
    """Returns the Python type for the return value of the given function."""
    if not returns:
        return cg.TypingRefs.Any()

    def to_node(ret):
        # Map one return declaration to a type reference node.
        type_name = arg_type_to_str(ret.type)
        if type_is_builtin(type_name):
            # Python type
            node: cg.TypeRef = cg.BuiltinTypeRef(type_name)
        elif type_name == "TensorType":
            node = cg.TypeRef("onnxscript.onnx_types", "TensorType")
        else:
            node = cg.TypeRef("onnxscript", type_name)
        if ret.type.is_nullable():
            node = cg.TypingRefs.Optional(node)
        return node

    nodes = [to_node(ret) for ret in returns]
    # A single return is used as-is; multiple returns become a tuple type.
    return nodes[0] if len(nodes) == 1 else cg.BuiltinTypeRef("tuple", *nodes)
| 189 | + |
| 190 | + |
def format_arg_name(arg: torchgen.model.Argument) -> str:
    """Returns the python compatible name of the given argument."""
    # "from" is a python keyword, so it gets a trailing underscore.
    name = arg.name
    return f"{name}_" if name == "from" else name  # type: ignore[no-any-return]
| 196 | + |
| 197 | + |
def create_signature(func: torchgen.model.NativeFunction) -> cg.FunctionDef:
    """Creates the signature for the given function."""
    op_name = get_op_name(func)

    # SelfArgument wraps the real argument; unwrap it. TensorOptionsArguments
    # (dtype/layout/device bundles) are dropped from the keyword arguments.
    positional = [
        arg.argument if isinstance(arg, torchgen.model.SelfArgument) else arg
        for arg in func.func.arguments.positional
    ]
    keyword_only = [
        arg
        for arg in func.func.arguments.kwarg_only
        if not isinstance(arg, torchgen.model.TensorOptionsArguments)
    ]

    def as_py_arg(arg, **extra):
        # A schema default becomes a python default value; arguments without
        # one stay required.
        default = (
            cg.Constant(parse_default_value(arg)) if arg.default is not None else None
        )
        return cg.Arg(
            format_arg_name(arg), get_argument_type(arg), default_value=default, **extra
        )

    py_args = [as_py_arg(arg) for arg in positional]
    # Arguments after this point are keyword-only.
    py_args.extend(as_py_arg(kwarg, is_kwarg=True) for kwarg in keyword_only)

    return cg.FunctionDef(
        op_name,
        *py_args,
        return_type=create_return_type(func.func.returns),
        body=[
            cg.ThunkStmt(f"# {func.func}"),
            cg.Raise(cg.Call(cg.Name("NotImplementedError"))),  # type: ignore[list-item]
        ],
    )
| 244 | + |
| 245 | + |
def create_onnx_function_module(
    functions: Sequence[torchgen.model.NativeFunction],
) -> cg.Module:
    """Creates the onnx function module."""
    # Generated modules start with a `from __future__ import annotations`
    # so the emitted signatures may use forward references.
    signatures = (
        create_signature(func) for func in functions if should_generate_signature(func)
    )
    return cg.Module(cg.ImportFrom("__future__", cg.Alias("annotations")), *signatures)
| 254 | + |
| 255 | + |
def copyright_header() -> str:
    """Creates the copyright header."""
    dashline = f"# {'-' * 74}"
    header_lines = [
        dashline,
        "# Copyright (c) Microsoft Corporation. All rights reserved.",
        "# Licensed under the MIT License.",
        dashline,
        "# mypy: disable-error-code=misc",
        "# mypy: disable-error-code=type-arg",
        "# mypy: disable-error-code=valid-type",
        "# mypy: disable-error-code=assignment",
    ]
    # Trailing newline matches the dedented triple-quoted form this replaces.
    return "\n".join(header_lines) + "\n"
| 271 | + |
| 272 | + |
def main(args: argparse.Namespace) -> None:
    """Generates one python module of ATen signatures per torch python_module.

    Parses native_functions.yaml (args.yaml), groups the eligible functions by
    their ``python_module`` (defaulting to "core"), and writes one
    ``<module>.py`` file per group into args.outdir.
    """
    native_functions, _ = parse_native_functions_yaml(args.yaml)
    # module name -> op name -> native function; the inner dict deduplicates
    # operators within a module.
    functions: dict[str, dict[str, torchgen.model.NativeFunction]] = {}
    for func in native_functions:
        if not should_generate_signature(func):
            continue

        module_name = typing.cast(str, func.python_module)
        if not module_name:
            # Functions without an explicit python_module go into "core".
            module_name = "core"
        if module_name not in functions:
            functions[module_name] = {}
        op_name = get_op_name(func)
        if op_name in functions[module_name]:
            # The first occurrence wins; later duplicates are only logged.
            logging.warning(
                "Duplicated function: %s, overload: %s", op_name, func.func.name.overload_name
            )
            continue
        functions[module_name][op_name] = func

    os.makedirs(args.outdir, exist_ok=True)

    for module_name, module_functions in functions.items():
        # Emit functions in deterministic, name-sorted order.
        sorted_functions = sorted(module_functions.items(), key=lambda x: x[0])
        py_module = create_onnx_function_module([func for _, func in sorted_functions])
        py_module.accept(cg.ImportAdjuster())
        py_module.accept(cg.DocCommentBuilder())
        output_path = os.path.join(args.outdir, f"{module_name}.py")

        print(f"Generating {output_path}")
        with open(output_path, "w", encoding="utf-8") as f:
            f.write(copyright_header())
            # Add docstring
            f.write(
                textwrap.dedent(
                    f'''\
                    """torch.ops.aten operators under the `{module_name}` module.

                    - No inplace operators.
                    - All functions should not have the script() decorator. This is because
                    we want to delay the compilation of the function.
                    """
                    '''
                )
            )
            py_module.accept(cg.PythonWriter(f))
    print("Done.")
2 | 320 |
|
3 | 321 |
|
4 | | -def main(): |
5 | | - pass |
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate ATen operator signatures from native_functions.yaml."
    )
    parser.add_argument(
        "--yaml",
        type=str,
        # Required: without it args.yaml is None and main() fails later with
        # a confusing error when opening the file.
        required=True,
        help="Path to PyTorch aten/src/ATen/native/native_functions.yaml",
    )
    parser.add_argument(
        "--outdir",
        type=str,
        # Required: os.makedirs(None) would otherwise raise a TypeError.
        required=True,
        help="Output directory for generated modules",
    )
    main(parser.parse_args())
0 commit comments