diff --git a/onnxscript/_internal/ast_utils.py b/onnxscript/_internal/ast_utils.py new file mode 100644 index 0000000000..bf73d8b88b --- /dev/null +++ b/onnxscript/_internal/ast_utils.py @@ -0,0 +1,24 @@ +"""Utilities for working with Python ASTs.""" +from __future__ import annotations + +import ast +import inspect +import textwrap +import types + + +def get_src_and_ast(f: types.FunctionType) -> tuple[str, ast.FunctionDef]: + try: + src = inspect.getsource(f) + except OSError as e: + raise RuntimeError( + f"Decorator script does not work on dynamically " + f"compiled function {f.__name__}." + ) from e + src = textwrap.dedent(src) + top_level_ast = ast.parse(src) + assert isinstance(top_level_ast, ast.Module) + assert len(top_level_ast.body) == 1 + f_ast = top_level_ast.body[0] + assert isinstance(f_ast, ast.FunctionDef) + return src, f_ast diff --git a/onnxscript/analysis_test.py b/onnxscript/analysis_test.py index 521ea304f1..362d44dab9 100644 --- a/onnxscript/analysis_test.py +++ b/onnxscript/analysis_test.py @@ -4,7 +4,8 @@ import unittest from typing import Any -from onnxscript import analysis, main +from onnxscript import analysis +from onnxscript._internal import ast_utils from onnxscript.onnx_opset import opset15 as op from onnxscript.sourceinfo import formatter @@ -27,7 +28,7 @@ def generic_visit(self, node): class TestLivenessAnalysis(unittest.TestCase): def analyze(self, fun): - source, parse_tree = main.get_src_and_ast(fun) + source, parse_tree = ast_utils.get_src_and_ast(fun) analysis.do_liveness_analysis(parse_tree, formatter(source)) visitor = AnalysisResultsVisitor() visitor.visit(parse_tree) @@ -97,7 +98,7 @@ def while_eg(x): class TestExposedUses(unittest.TestCase): def assertUses(self, f, expected): - source, parse_tree = main.get_src_and_ast(f) + source, parse_tree = ast_utils.get_src_and_ast(f) result = analysis.exposed_uses(parse_tree.body, formatter(source)) self.assertEqual(result, set(expected)) diff --git a/onnxscript/function_libs/torch_lib/registration.py b/onnxscript/function_libs/torch_lib/registration.py index d3fc3d4c12..e7f8f32dfd 100644 --- a/onnxscript/function_libs/torch_lib/registration.py +++ b/onnxscript/function_libs/torch_lib/registration.py @@ -67,7 +67,7 @@ def torch_op( registry: Optional[Registry] = None, trace_only: bool = False, private: bool = False, -) -> Callable[[Callable[..., Any]], onnxscript.OnnxFunction | Callable[..., Any]]: +) -> Callable[[FunctionType], onnxscript.OnnxFunction | onnxscript.values.TracedOnnxFunction]: """Register a torch op. Args: @@ -81,12 +81,16 @@ def torch_op( if registry is None: registry = default_registry - def wrapper(func: Callable[..., Any]) -> onnxscript.OnnxFunction | Callable[..., Any]: + def wrapper( + func: FunctionType, + ) -> onnxscript.OnnxFunction | onnxscript.values.TracedOnnxFunction: + # Compile the function + custom_opset = onnxscript.values.Opset(domain="onnxscript.atenlib", version=1) + + processed_func: onnxscript.OnnxFunction | onnxscript.values.TracedOnnxFunction if trace_only: - processed_func = func + processed_func = onnxscript.values.TracedOnnxFunction(custom_opset, func) else: - # Compile the function - custom_opset = onnxscript.values.Opset(domain="onnxscript.atenlib", version=1) assert isinstance(func, FunctionType) processed_func = onnxscript.script(opset=custom_opset)(func) diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index 443f18412e..98cd6a462b 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -202,7 +202,7 @@ def __str__(self): args = _format(self.args, "(", ", ", ")", _opt_var_to_str) domain = self.callee.opset.domain - opname = self.callee.opname + opname = self.callee.name callee = f"{domain}.{opname}" if (domain != "") else opname return f"{lhs} = {callee} {attrs}{args}" @@ -212,7 +212,7 @@ def debug_print(self): def to_node_proto(self, node_name: str) -> onnx.NodeProto: n = helper.make_node( - self.callee.opname, + self.callee.name, [_opt_var_to_str(x) for x in self.args], [str(x) for x in self.result], domain=self.callee.opset.domain, diff --git a/onnxscript/main.py b/onnxscript/main.py index 4f5431fe63..2dddb54d00 100644 --- a/onnxscript/main.py +++ b/onnxscript/main.py @@ -8,7 +8,6 @@ import ast import inspect import sys -import textwrap import types from typing import Any, Callable, Optional, Sequence, cast @@ -16,28 +15,7 @@ import onnxscript from onnxscript import converter, irbuilder, values - - -def get_src_and_ast(f: types.FunctionType) -> tuple[str, ast.FunctionDef]: - try: - src = inspect.getsource(f) - except OSError as e: - raise RuntimeError( - f"Decorator script does not work on dynamically " - f"compiled function {f.__name__}." - ) from e - src = textwrap.dedent(src) - top_level_ast = ast.parse(src) - assert isinstance(top_level_ast, ast.Module) - assert len(top_level_ast.body) == 1 - f_ast = top_level_ast.body[0] - assert isinstance(f_ast, ast.FunctionDef) - return src, f_ast - - -def get_ast(f: types.FunctionType) -> ast.FunctionDef: - _, f_ast = get_src_and_ast(f) - return f_ast +from onnxscript._internal import ast_utils def script_check( @@ -104,7 +82,7 @@ def transform(f: types.FunctionType) -> onnxscript.OnnxFunction: if not inspect.isfunction(f): raise TypeError("The ONNXScript decorator should be applied to functions only.") - src, f_ast = get_src_and_ast(f) # pylint: disable=redefined-outer-name + src, f_ast = ast_utils.get_src_and_ast(f) # The script should be compiled using the globals/locals at the definition site. # This allows the script to reference names defined outside the script, # which is used for a few different purposes. diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test.py b/onnxscript/tests/function_libs/torch_lib/ops_test.py index df3da0a0e8..ee10037484 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test.py @@ -119,6 +119,19 @@ def test_script_function_has_op_schema(self, _, func_with_wrangler): self.assertIsNotNone(schema) self.assertEqual(schema.name, func.name) + @parameterized.parameterized.expand( + list(ops_test_data.OPINFO_FUNCTION_MAPPING_TRACE_ONLY.items()) + ) + @unittest.skipIf( + version_utils.onnx_older_than("1.15"), + "OpSchema is not writable before ONNX 1.15", + ) + def test_trace_only_function_has_op_schema(self, _, func_with_wrangler): + func, _ = _split_function_and_wrangler(func_with_wrangler) + schema = func.opschema + self.assertIsNotNone(schema) + self.assertEqual(schema.name, func.name) + def run_test_output_match( test_suite: unittest.TestCase, diff --git a/onnxscript/values.py b/onnxscript/values.py index c010f77f31..7d6ac204a0 100644 --- a/onnxscript/values.py +++ b/onnxscript/values.py @@ -5,17 +5,20 @@ from __future__ import annotations import dataclasses +import inspect import logging import types +import typing from enum import IntFlag from typing import _GenericAlias # type: ignore[attr-defined] -from typing import Any, Optional, Sequence +from typing import Any, Optional, Protocol, Sequence import onnx import onnx.defs +from onnxscript import converter as converter_module from onnxscript import irbuilder, sourceinfo, type_annotation -from onnxscript._internal import version_utils +from onnxscript._internal import ast_utils, version_utils _ATTRIBUTE_TYPE_TO_PYTHON_TYPE = { onnx.defs.OpSchema.AttrType.FLOAT: float, @@ -50,6 +53,8 @@ class Opset: Only a single instance of Opset is created for a given (domain, version) pair. """ + domain: str + version: int cache: dict[tuple[type, str, int], Opset] = {} def __new__(cls, domain: str, version: int): @@ -161,21 +166,106 @@ def _get_attribute_value(attr_proto: onnx.AttributeProto) -> Any: return onnx.helper.get_attribute_value(attr_proto) -class Op: +def param_schemas_from_op_schema( + op_schema: onnx.defs.OpSchema, +) -> tuple[ParamSchema, ...]: + """Get the parameter schemas from an ONNX OpSchema.""" + schemas = [] + for input_ in op_schema.inputs: + param_schema = ParamSchema( + name=input_.name, + is_input=True, + required=(input_.option != onnx.defs.OpSchema.FormalParameterOption.Optional), + is_variadic_input=( + input_.option == onnx.defs.OpSchema.FormalParameterOption.Variadic + ), + ) + schemas.append(param_schema) + for attr_name, attribute in op_schema.attributes.items(): + default_attr_proto = attribute.default_value + param_schema = ParamSchema( + name=attr_name, + type=_ATTRIBUTE_TYPE_TO_PYTHON_TYPE[attribute.type], + default=_get_attribute_value(default_attr_proto), + is_input=False, + required=attribute.required, + ) + schemas.append(param_schema) + + return tuple(schemas) + + +def param_schemas_from_function_ir( + function_ir: irbuilder.IRFunction, +) -> tuple[ParamSchema, ...]: + """Get the parameter schemas from a FunctionIR.""" + # The first len(func_ir.inputs) arguments are onnx inputs + # The rest is onnx attributes + + schemas = [] + for arg in function_ir.inputs: + if isinstance(arg.typeinfo, onnx.TypeProto.Optional): + required = False + else: + required = True + schemas.append( + ParamSchema(name=arg.name, type=arg.typeinfo, is_input=True, required=required) + ) + + for attr_parameter in function_ir.attrs: + schemas.append( + ParamSchema( + name=attr_parameter.name, + type=_ATTRIBUTE_TYPE_TO_PYTHON_TYPE.get( + onnx.defs.OpSchema.AttrType(attr_parameter.type) # type: ignore[call-arg] + ), + default=_EmptyDefault + if attr_parameter.default_value is None + else attr_parameter.default_value, + is_input=False, + required=not attr_parameter.has_default, + ) + ) + + return tuple(schemas) + + +@typing.runtime_checkable +class OpLike(Protocol): + """A protocol for objects that have an ONNX OpSchema.""" + + @property + def name(self) -> str: + ... + + @property + def opset(self) -> Opset: + ... + + @property + def opschema(self) -> Optional[onnx.defs.OpSchema]: + ... + + def param_schemas(self) -> Optional[tuple[ParamSchema, ...]]: + ... + + +class Op(OpLike): """Represents an ONNX op instance (for example, the MatMul op from ONNX opset version 13). + It belongs to a particular Opset and has a name. Attributes: opset: The Opset that this op belongs to. - opname: The name of the op. + name: The name of the op. opschema: The ONNX OpSchema for the op. """ def __init__( - self, opset, opname: str, opschema: Optional[onnx.defs.OpSchema] = None + self, opset: Opset, opname: str, opschema: Optional[onnx.defs.OpSchema] = None ) -> None: - self.opset = opset - self.opname = opname + self._opset = opset + self._name = opname self._opschema = opschema self._param_schemas: Optional[tuple[ParamSchema, ...]] = None @@ -186,12 +276,20 @@ def __call__(self, *args, **kwargs): schema = self.get_schema() if schema is None: raise RuntimeError( - f"Op '{self.opname}' does not have an OpSchema and cannot be evaluated." + f"Op '{self.name}' does not have an OpSchema and cannot be evaluated." ) return evaluator.default().eval(schema, args, kwargs) def is_single_op(self) -> bool: - return isinstance(self.opname, str) + return isinstance(self.name, str) + + @property + def name(self) -> str: + return self._name + + @property + def opset(self) -> Opset: + return self._opset @property def opschema(self) -> Optional[onnx.defs.OpSchema]: @@ -201,7 +299,7 @@ def get_schema(self) -> Optional[onnx.defs.OpSchema]: """Returns the ONNX OpSchema for this op.""" if self.opschema is not None: return self.opschema - return self.opset[self.opname] + return self.opset[self.name] def has_schema(self) -> bool: """Returns True if this op has an OpSchema.""" @@ -215,30 +313,9 @@ def param_schemas(self) -> Optional[tuple[ParamSchema, ...]]: op_schema = self.get_schema() if op_schema is None: return None - schemas = [] - for input_ in op_schema.inputs: - param_schema = ParamSchema( - name=input_.name, - is_input=True, - required=(input_.option != onnx.defs.OpSchema.FormalParameterOption.Optional), - is_variadic_input=( - input_.option == onnx.defs.OpSchema.FormalParameterOption.Variadic - ), - ) - schemas.append(param_schema) - for attr_name, attribute in op_schema.attributes.items(): - default_attr_proto = attribute.default_value - param_schema = ParamSchema( - name=attr_name, - type=_ATTRIBUTE_TYPE_TO_PYTHON_TYPE[attribute.type], - default=_get_attribute_value(default_attr_proto), - is_input=False, - required=attribute.required, - ) - schemas.append(param_schema) - self._param_schemas = tuple(schemas) - return self._param_schemas # type: ignore[return-value] + self._param_schemas = param_schemas_from_op_schema(op_schema) + return self._param_schemas @dataclasses.dataclass(repr=False, eq=False) @@ -333,7 +410,7 @@ def op_schema_from_function_ir( *[ onnx.defs.OpSchema.Attribute( attr.name, - type=onnx.defs.OpSchema.AttrType(attr.type), + type=onnx.defs.OpSchema.AttrType(attr.type), # type: ignore[call-arg] ) for attr in function_ir.attrs if not attr.has_default @@ -353,23 +430,34 @@ def op_schema_from_function_ir( class OnnxFunction(Op): """Represents an ONNX op for which a function-body has been defined in onnxscript. - Args: - opset: opset the function belongs to - pyfun: python function - irfun: python code parsed by class - :class:`onnxscript.converter.Converter` - source: source code used to generate the function - kwargs: additional properties used to construct a ModelProto + Attributes: + opset: Opset the function belongs to. + name: Name of the function. + function: Python function. + function_ir: Python code parsed as an :class:`irbuilder.IRFunction`. + source: Source code used to generate the function. + kwargs: Additional properties used to construct a ModelProto. + opschema: Generated ONNX OpSchema for this op. """ def __init__( self, - opset: Opset, + opset: Optional[Opset], pyfun: types.FunctionType, irfun: irbuilder.IRFunction, source: str, kwargs: dict[str, Any], ): + """Constructs an OnnxFunction. + + Args: + opset: opset the function belongs to + pyfun: python function + irfun: python code parsed by class + :class:`onnxscript.converter.Converter` + source: source code used to generate the function + kwargs: additional properties used to construct a ModelProto + """ opset = opset or Opset(irfun.domain, 1) super().__init__(opset, irfun.name) self.function = pyfun @@ -379,11 +467,6 @@ def __init__( self._param_schemas: Optional[tuple[ParamSchema, ...]] = None self._opschema: Optional[onnx.defs.OpSchema] = None - @property - def name(self): - """Returns the function name.""" - return self.opname - @property def opschema(self) -> Optional[onnx.defs.OpSchema]: """Construct an OpSchema from function_ir.""" @@ -429,38 +512,8 @@ def param_schemas(self) -> tuple[ParamSchema, ...]: # NOTE: We generate the parameter schemas from the function_ir instead # of relying on the auto generated OpSchema because we need to preserve the keyword # argument order from the Python function definition, which is lost in OpSchema. - function_ir = self.function_ir - # The first len(func_ir.inputs) arguments are onnx inputs - inputs = function_ir.inputs - # The rest is onnx attributes - - schemas = [] - for arg in inputs: - if isinstance(arg.typeinfo, onnx.TypeProto.Optional): - required = False - else: - required = True - schemas.append( - ParamSchema(name=arg.name, type=arg.typeinfo, is_input=True, required=required) - ) - - for attr_parameter in function_ir.attrs: - schemas.append( - ParamSchema( - name=attr_parameter.name, - type=_ATTRIBUTE_TYPE_TO_PYTHON_TYPE.get( - onnx.defs.OpSchema.AttrType(attr_parameter.type) # type: ignore[call-arg] - ), - default=_EmptyDefault - if attr_parameter.default_value is None - else attr_parameter.default_value, - is_input=False, - required=not attr_parameter.has_default, - ) - ) - - self._param_schemas = tuple(schemas) - return self._param_schemas # type: ignore[return-value] + self._param_schemas = param_schemas_from_function_ir(self.function_ir) + return self._param_schemas def to_function_proto(self): """Converts the function into :class:`onnx.FunctionProto`.""" @@ -484,6 +537,76 @@ def to_model_proto(self, **kwargs): return self.function_ir.to_model_proto(**merged_kw_args) +class TracedOnnxFunction(Op): + """TracedOnnxFunction. + + Attributes: + name: Name of the op. E.g. "aten::add". + func: Function. + """ + + def __init__(self, opset: Opset, func: types.FunctionType): + super().__init__(opset, func.__name__) + self.func = func + + def __call__(self, *args, **kwargs): + return self.func(*args, **kwargs) + + def __repr__(self): + return f"{self.__class__.__name__}({self.func!r})" + + @property + def name(self) -> str: + """Return the name of the op.""" + return self.func.__name__ + + @property + def function_ir(self) -> irbuilder.IRFunction: + """Return the function_ir. + + This function IR contains only the signature of the function. + """ + src, func_ast = ast_utils.get_src_and_ast(self.func) + module = inspect.getmodule(self.func) + closure = inspect.getclosurevars(self.func) + global_names = module.__dict__.copy() + global_names.update(closure.nonlocals) + converter = converter_module.Converter( + opset=self._opset, + global_names=global_names, + source=src, + ) + + return converter.translate_function_signature(func_ast) + + @property + def opschema(self) -> Optional[onnx.defs.OpSchema]: + """Return the opschema.""" + + if self._opschema is not None: + return self._opschema + + if not _ONNX_OP_SCHEMA_WRITABLE: + return None + + # FIXME(justinchuby): outputs are empty. Need to fix. + self._opschema = op_schema_from_function_ir(self.function_ir, self._opset) + + return self._opschema + + def param_schemas(self) -> tuple[ParamSchema, ...]: + """Returns the parameter schemas of this function.""" + if self._param_schemas is not None: + return self._param_schemas + + # NOTE: We generate the parameter schemas from the function_ir instead + # of relying on the auto generated OpSchema because we need to preserve the keyword + # argument order from the Python function definition, which is lost in OpSchema. + # FIXME(justinchuby): Fix param ordering when attributes come before inputs. + self._param_schemas = param_schemas_from_function_ir(self.function_ir) + return self._param_schemas + + class SymbolValue: """Represents script-time value information about named variables used in a script. diff --git a/onnxscript/values_test.py b/onnxscript/values_test.py new file mode 100644 index 0000000000..978d7ad2fb --- /dev/null +++ b/onnxscript/values_test.py @@ -0,0 +1,15 @@ +import unittest + +from onnxscript import values + + +class TracedOnnxFunctionTest(unittest.TestCase): + def test_init(self): + def function(input1, input2, attr1: int, attr2: int = 1): + return input1 + input2 + attr1 + attr2 + + opset = values.Opset("test", 1) + traced_function = values.TracedOnnxFunction(opset, function) + self.assertEqual(traced_function.opset, opset) + self.assertEqual(traced_function.name, function.__name__) + self.assertEqual(traced_function.func, function)