From c6505cf19829d264c37873a123404cc22bac401b Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Wed, 12 Apr 2023 17:30:24 +0000
Subject: [PATCH 01/22] Move version_utils to _internal | chore

Move version_utils to `_internal` so that it can be used by onnxscript

[ghstack-poisoned]
---
 onnxscript/{tests/common => _internal}/version_utils.py    | 0
 onnxscript/function_libs/torch_aten/graph_building_test.py | 2 +-
 .../tests/function_libs/torch_aten/ops_correctness_test.py | 2 +-
 onnxscript/tests/functions/onnxfns1A_test.py               | 3 ++-
 onnxscript/tests/functions/onnxfns2_test.py                | 3 ++-
 onnxscript/tests/functions/onnxfns_test.py                 | 3 ++-
 6 files changed, 8 insertions(+), 5 deletions(-)
 rename onnxscript/{tests/common => _internal}/version_utils.py (100%)

diff --git a/onnxscript/tests/common/version_utils.py b/onnxscript/_internal/version_utils.py
similarity index 100%
rename from onnxscript/tests/common/version_utils.py
rename to onnxscript/_internal/version_utils.py
diff --git a/onnxscript/function_libs/torch_aten/graph_building_test.py b/onnxscript/function_libs/torch_aten/graph_building_test.py
index 0257fa3315..60646b8768 100644
--- a/onnxscript/function_libs/torch_aten/graph_building_test.py
+++ b/onnxscript/function_libs/torch_aten/graph_building_test.py
@@ -10,8 +10,8 @@
 import onnxscript.testing
 from onnxscript import FLOAT, evaluator
 from onnxscript import opset17 as op
+from onnxscript._internal import version_utils
 from onnxscript.function_libs.torch_aten import graph_building, ops
-from onnxscript.tests.common import version_utils


 @unittest.skipIf(version_utils.torch_older_than("2.0"), "torchscript in 1.13 not supported")
diff --git a/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
index 011ba96a43..897a8cdcdb 100644
--- a/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
+++ b/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
@@ -58,11 +58,11 @@
 import onnxscript
 import onnxscript.evaluator
+from onnxscript._internal import version_utils
 from onnxscript.function_libs.torch_aten import graph_building
 from onnxscript.function_libs.torch_aten.ops import core as core_ops
 from onnxscript.function_libs.torch_aten.ops import nn as nn_ops
 from onnxscript.function_libs.torch_aten.ops import special as special_ops
-from onnxscript.tests.common import version_utils
 from onnxscript.tests.function_libs.torch_aten import extra_opinfo

 T = TypeVar("T")
diff --git a/onnxscript/tests/functions/onnxfns1A_test.py b/onnxscript/tests/functions/onnxfns1A_test.py
index 00148e5632..a9c22aba43 100644
--- a/onnxscript/tests/functions/onnxfns1A_test.py
+++ b/onnxscript/tests/functions/onnxfns1A_test.py
@@ -3,7 +3,8 @@
 import onnx
 import pytest

-from onnxscript.tests.common import onnx_script_test_case, version_utils
+from onnxscript._internal import version_utils
+from onnxscript.tests.common import onnx_script_test_case
 from onnxscript.tests.models import onnxfns1A


diff --git a/onnxscript/tests/functions/onnxfns2_test.py b/onnxscript/tests/functions/onnxfns2_test.py
index db4aa40ea7..55ccaeea17 100644
--- a/onnxscript/tests/functions/onnxfns2_test.py
+++ b/onnxscript/tests/functions/onnxfns2_test.py
@@ -3,7 +3,8 @@
 import onnxruntime
 import pytest

-from onnxscript.tests.common import onnx_script_test_case, version_utils
+from onnxscript._internal import version_utils
+from onnxscript.tests.common import onnx_script_test_case
 from onnxscript.tests.models import onnxfns2

diff --git a/onnxscript/tests/functions/onnxfns_test.py b/onnxscript/tests/functions/onnxfns_test.py
index fd0c5ff3a8..e0b37fb0f4 100644
--- a/onnxscript/tests/functions/onnxfns_test.py
+++ b/onnxscript/tests/functions/onnxfns_test.py
@@ -8,7 +8,8 @@
 import onnx
 import pytest

-from onnxscript.tests.common import onnx_script_test_case, version_utils
+from onnxscript._internal import version_utils
+from onnxscript.tests.common import onnx_script_test_case
 from onnxscript.tests.models import onnxfns1

From 268ad8f142194994c5d920ab10a0e6daee9b3914 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Wed, 12 Apr 2023 17:53:18 +0000
Subject: [PATCH 02/22] Merge `attrs` and `attr_protos` in `IRFunction` | chore(irbuilder)

[ghstack-poisoned]
---
 onnxscript/converter.py       |  1 +
 onnxscript/irbuilder.py       | 66 +++++++++++++++++++++--------------
 onnxscript/type_annotation.py |  4 +--
 onnxscript/values.py          | 41 ++++++++++++----------
 4 files changed, 64 insertions(+), 48 deletions(-)

diff --git a/onnxscript/converter.py b/onnxscript/converter.py
index 278989a6fc..75f3a8dafb 100644
--- a/onnxscript/converter.py
+++ b/onnxscript/converter.py
@@ -1326,6 +1326,7 @@ def translate_function_def(self, fn: ast.FunctionDef) -> irbuilder.IRFunction:
             self.ir_builder.add_attr_parameter(
                 self.current_fn,
                 x.arg,
+                ta.pytype_to_attrtype(typeinfo),
                 default_value,
             )
             self.bind(x.arg, values.AttrRef(x.arg, typeinfo, self.source_of(x)))
diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py
index 2481b42ae6..bcb2f9d3f6 100644
--- a/onnxscript/irbuilder.py
+++ b/onnxscript/irbuilder.py
@@ -108,10 +108,16 @@ def _opt_var_to_str(x):
 class IRAttributeValue:
-    """An attribute value (representing an actual parameter)."""
+    """An attribute value (representing an actual parameter).

-    def __init__(self, attrproto) -> None:
+    Attributes:
+        attr_proto: The attribute proto
+        has_default: Whether the attribute has a default value.
+ """ + + def __init__(self, attrproto, has_default: bool) -> None: self.attr_proto = attrproto + self.has_default = has_default def __str__(self): if self.attr_proto.HasField("ref_attr_name"): @@ -191,9 +197,7 @@ def __init__(self, name: str, domain: str = "") -> None: self.outputs: list[IRVar] = [] self.stmts: list[IRStmt] = [] # attribute parameters - self.attrs: list[str] = [] - # attribute parameters with default value - self.attr_protos: list[IRAttributeValue] = [] + self.attrs: list[IRAttributeValue] = [] self.called_functions: dict[str, onnx.FunctionProto] = {} self.docstring: str = "" # a dictionary of nested function-definitions @@ -207,11 +211,10 @@ def assigned_names(self) -> Sequence[str]: def __str__(self): attrs = _format(self.attrs, "<", ", ", ">") if self.attrs else "" - attr_protos = _format(self.attr_protos, "<", ", ", ">") if self.attr_protos else "" inputs = _format([x.typed_str() for x in self.inputs], "(", ", ", ")") outputs = _format([x.typed_str() for x in self.outputs], "(", ", ", ")") stmts = _format(self.stmts, "\n{\n ", "\n ", "\n}\n") - return f"{self.name} {attrs}{attr_protos}{inputs} => {outputs}{stmts}" + return f"{self.name} {attrs}{inputs} => {outputs}{stmts}" def append_docstring(self, docstring): self.docstring += docstring @@ -225,11 +228,8 @@ def append_input(self, name: IRVar) -> None: def append_output(self, name: IRVar) -> None: self.outputs.append(name) - def add_attr_parameter(self, attr: str | IRAttributeValue) -> None: - if isinstance(attr, IRAttributeValue): - self.attr_protos.append(attr) - else: - self.attrs.append(attr) + def add_attr_parameter(self, attr: IRAttributeValue) -> None: + self.attrs.append(attr) def debug_print(self): if logger.isEnabledFor(logging.DEBUG): @@ -398,19 +398,19 @@ def to_function_proto(self) -> onnx.FunctionProto: onnx.helper.make_opsetid(domain, version) for domain, version in opsets.items() ] - # attribute_proto is introduced in version onnx==1.13.0. + # attribute_proto is introduced in version onnx==1.14.0. # If this attribute is available, onnxscript uses it to # default values for attributes. The function has then two # lists, one list for attributes without default values, # another one for attributes with default values. # If this *attribute_proto* is not available, - # all attributes with a default value are moved to the first + # all attributes are moved to the first # list, default values are removed. # TODO: remove this when onnx with attribute_proto is released. 
if hasattr(onnx.FunctionProto, "attribute_proto"): - atts = self.attrs + attribute_names = [attr.name for attr in self.attrs if not attr.has_default] else: - atts = self.attrs + [a.attr_proto.name for a in self.attr_protos] + attribute_names = [attr.name for attr in self.attrs] f = helper.make_function( self.domain, @@ -419,11 +419,13 @@ def to_function_proto(self) -> onnx.FunctionProto: outputs=[y.name for y in self.outputs], nodes=nodes, opset_imports=opset_imports, # TODO - attributes=atts, + attributes=attribute_names, doc_string=self.docstring, ) if hasattr(onnx.FunctionProto, "attribute_proto"): - f.attribute_proto.extend([a.attr_proto for a in self.attr_protos]) + f.attribute_proto.extend( + [attr.attr_proto for attr in self.attrs if attr.has_default] + ) return f @@ -463,25 +465,35 @@ def add_input( v = IRVar(varname, type, info) fn.append_input(v) - def add_attr_parameter(self, fn: IRFunction, varname: str, default_value) -> None: + def add_attr_parameter( + self, + fn: IRFunction, + varname: str, + attribute_type: onnx.AttributeProto.AttributeType, + default_value, + ) -> None: if default_value is not None: - a = IRAttributeValue(helper.make_attribute(varname, default_value)) - fn.add_attr_parameter(a) + fn.add_attr_parameter( + IRAttributeValue( + helper.make_attribute(varname, default_value), has_default=True + ) + ) else: - fn.add_attr_parameter(varname) + proto = onnx.AttributeProto() + proto.name = varname + proto.type = attribute_type + fn.add_attr_parameter(IRAttributeValue(proto, has_default=False)) def add_output(self, fn: IRFunction, varname: str, type, info) -> None: v = IRVar(varname, type, info) fn.append_output(v) def make_attr(self, attrname: str, attrval: Any) -> IRAttributeValue: - return IRAttributeValue(helper.make_attribute(attrname, attrval)) + return IRAttributeValue(helper.make_attribute(attrname, attrval), has_default=True) def make_attr_ref(self, attrname: str, refname: str, pytype: type) -> IRAttributeValue: a = onnx.AttributeProto() a.name = attrname a.ref_attr_name = refname - type_ = ta.pytype_to_attrtype(pytype) - assert type_ is not None - a.type = type_ - return IRAttributeValue(a) + a.type = ta.pytype_to_attrtype(pytype) + return IRAttributeValue(a, has_default=False) diff --git a/onnxscript/type_annotation.py b/onnxscript/type_annotation.py index ee7a2cc554..c5e9f0a704 100644 --- a/onnxscript/type_annotation.py +++ b/onnxscript/type_annotation.py @@ -59,7 +59,7 @@ def _is_primitive_attr_type(typeinfo: TypeAnnotationValue) -> bool: def pytype_to_attrtype( pytype: TypeAnnotationValue, -) -> typing.Optional[onnx.AttributeProto.AttributeType]: +) -> onnx.AttributeProto.AttributeType: pytype = _remove_annotation(pytype) if pytype in _PYTYPE_TO_ATTRTYPE_MAP: return _PYTYPE_TO_ATTRTYPE_MAP[pytype] @@ -74,7 +74,7 @@ def pytype_to_attrtype( elt_type = get_args(pytype)[0] if elt_type in _LISTTYPE_TO_ATTRTYPE_MAP: return _LISTTYPE_TO_ATTRTYPE_MAP[elt_type] - return None + return onnx.AttributeProto.UNDEFINED def _is_tensor_type(typeinfo: TypeAnnotationValue) -> bool: diff --git a/onnxscript/values.py b/onnxscript/values.py index 04379685bd..121cac2d8a 100644 --- a/onnxscript/values.py +++ b/onnxscript/values.py @@ -311,11 +311,10 @@ def param_schemas(self) -> tuple[ParamSchema, ...]: # The first len(func_ir.inputs) arguments are onnx inputs inputs = function_ir.inputs # The rest is onnx attributes - attributes = function_ir.attrs # Construct a dictionary of attributes with their names specified in the function # definition attr_name_to_protos = 
-            (attr.name, attr) for attr in function_ir.attr_protos
+            (attr.name, attr) for attr in function_ir.attrs
         )

         # args with default value are attributes
@@ -325,26 +324,30 @@ def param_schemas(self) -> tuple[ParamSchema, ...]:
                 required = False
             else:
                 required = True
-            param_schema = ParamSchema(
-                name=arg.name, type=arg.typeinfo, is_input=True, required=required
+            schemas.append(
+                ParamSchema(name=arg.name, type=arg.typeinfo, is_input=True, required=required)
             )
-            schemas.append(param_schema)
-
-        for attr_name in attributes:
-            # Attributes without default values
-            # FIXME(justinchuby): Where can we find the type?
-            param_schema = ParamSchema(name=attr_name, type=None, is_input=False)
-            schemas.append(param_schema)

         for name, attr_value in attr_name_to_protos.items():
-            param_schema = ParamSchema(
-                name=name,
-                type=_ATTRIBUTE_TYPE_TO_PYTHON_TYPE[attr_value.type],
-                default=_get_attribute_value(attr_value.attr_proto),
-                is_input=False,
-                # All function attributes are required
-            )
-            schemas.append(param_schema)
+            if not attr_value.has_default:
+                schemas.append(
+                    ParamSchema(
+                        name=name,
+                        type=_ATTRIBUTE_TYPE_TO_PYTHON_TYPE[attr_value.type],
+                        is_input=False,
+                        required=True,
+                    )
+                )
+            else:
+                schemas.append(
+                    ParamSchema(
+                        name=name,
+                        type=_ATTRIBUTE_TYPE_TO_PYTHON_TYPE[attr_value.type],
+                        default=_get_attribute_value(attr_value.attr_proto),
+                        is_input=False,
+                        required=True,
+                    )
+                )

         self._param_schemas = tuple(schemas)
         return self._param_schemas  # type: ignore[return-value]

From 1b7391e5222bcaabd48e3f5f77c81042f0858cab Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Wed, 12 Apr 2023 22:08:18 +0000
Subject: [PATCH 03/22] Update base for Update on "Merge `attrs` and `attr_protos` in `IRFunction` | chore(irbuilder)"

Merge the two lists in `IRFunction` by adding a `has_default` field in
`IRAttributeValue`. This way we retain type information for all
attributes. It is useful for creating correct `OpSchema`s and
`ParamSchema`s in the next PR.

Also include `typeinfo` in `add_attr_parameter`.
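To illustrate, a rough sketch of the merged representation (not part of the
diff; the attribute names `alpha` and `dtype` are made up):

```python
import onnx
from onnx import helper

from onnxscript.irbuilder import IRAttributeValue

# A parameter declared with a default, e.g. `alpha: float = 1.0`,
# carries a fully populated AttributeProto:
with_default = IRAttributeValue(helper.make_attribute("alpha", 1.0), has_default=True)

# A parameter declared without a default, e.g. `dtype: int`,
# records only its name and type:
proto = onnx.AttributeProto()
proto.name = "dtype"
proto.type = onnx.AttributeProto.INT
without_default = IRAttributeValue(proto, has_default=False)

# Both kinds now live in the single `IRFunction.attrs` list, and
# `to_function_proto` routes them based on `has_default`.
```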
[ghstack-poisoned]

From 4e38a29f2bf6c87f189173e7d5b70ece85c1e81b Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Fri, 14 Apr 2023 01:41:23 +0000
Subject: [PATCH 11/22] Analyze type annotation for input types | feat(op_schema)

[ghstack-poisoned]
---
 onnxscript/onnx_types.py      | 87 ++++++++++++++++++++++++++++-------
 onnxscript/type_annotation.py | 64 ++++++++++++++++++++++++--
 2 files changed, 130 insertions(+), 21 deletions(-)

diff --git a/onnxscript/onnx_types.py b/onnxscript/onnx_types.py
index 32e48bf744..0e8f8b8b8c 100644
--- a/onnxscript/onnx_types.py
+++ b/onnxscript/onnx_types.py
@@ -105,69 +105,105 @@ def to_type_proto(cls) -> onnx.TypeProto:
             shape = [cls.shape]  # example: "FLOAT[10]"
         return onnx.helper.make_tensor_type_proto(cls.dtype, shape)

+    @classmethod
+    def to_string(cls) -> str:
+        raise NotImplementedError()
+

 class FLOAT(TensorType, dtype=onnx.TensorProto.FLOAT):
-    pass
+    @classmethod
+    def to_string(cls):
+        return "tensor(float)"


 class UINT8(TensorType, dtype=onnx.TensorProto.UINT8):
-    pass
+    @classmethod
+    def to_string(cls):
+        return "tensor(uint8)"


 class INT8(TensorType, dtype=onnx.TensorProto.INT8):
-    pass
+    @classmethod
+    def to_string(cls):
+        return "tensor(int8)"


 class UINT16(TensorType, dtype=onnx.TensorProto.UINT16):
-    pass
+    @classmethod
+    def to_string(cls):
+        return "tensor(uint16)"


 class INT16(TensorType, dtype=onnx.TensorProto.INT16):
-    pass
+    @classmethod
+    def to_string(cls):
+        return "tensor(int16)"


 class INT32(TensorType, dtype=onnx.TensorProto.INT32):
-    pass
+    @classmethod
+    def to_string(cls):
+        return "tensor(int32)"


 class INT64(TensorType, dtype=onnx.TensorProto.INT64):
-    pass
+    @classmethod
+    def to_string(cls):
+        return "tensor(int64)"


 class STRING(TensorType, dtype=onnx.TensorProto.STRING):
-    pass
+    @classmethod
+    def to_string(cls):
+        return "tensor(string)"


 class BOOL(TensorType, dtype=onnx.TensorProto.BOOL):
-    pass
+    @classmethod
+    def to_string(cls):
+        return "tensor(bool)"


 class FLOAT16(TensorType, dtype=onnx.TensorProto.FLOAT16):
-    pass
+    @classmethod
+    def to_string(cls):
+        return "tensor(float16)"


 class DOUBLE(TensorType, dtype=onnx.TensorProto.DOUBLE):
-    pass
+    @classmethod
+    def to_string(cls):
+        return "tensor(double)"


 class UINT32(TensorType, dtype=onnx.TensorProto.UINT32):
-    pass
+    @classmethod
+    def to_string(cls):
+        return "tensor(uint32)"


 class UINT64(TensorType, dtype=onnx.TensorProto.UINT64):
-    pass
+    @classmethod
+    def to_string(cls):
+        return "tensor(uint64)"


 class COMPLEX64(TensorType, dtype=onnx.TensorProto.COMPLEX64):
-    pass
+    @classmethod
+    def to_string(cls):
+        return "tensor(complex64)"


 class COMPLEX128(TensorType, dtype=onnx.TensorProto.COMPLEX128):
-    pass
+    @classmethod
+    def to_string(cls):
+        return "tensor(complex128)"


 class BFLOAT16(TensorType, dtype=onnx.TensorProto.BFLOAT16):
-    pass
+    @classmethod
+    def to_string(cls):
+        return "tensor(bfloat16)"


 def onnx_type_to_onnxscript_repr(onnx_type: onnx.TypeProto) -> str:
@@ -203,3 +239,22 @@ def onnx_type_to_onnxscript_repr(onnx_type: onnx.TypeProto) -> str:

 # Currently, only tensor types are supported. Need to expand support for other ONNX types.
 ONNXType = TensorType
+
+ALL_TENSOR_TYPES = (
+    BFLOAT16,
+    BOOL,
+    COMPLEX128,
+    COMPLEX64,
+    DOUBLE,
+    FLOAT,
+    FLOAT16,
+    INT16,
+    INT32,
+    INT64,
+    INT8,
+    STRING,
+    UINT16,
+    UINT32,
+    UINT64,
+    UINT8,
+)
diff --git a/onnxscript/type_annotation.py b/onnxscript/type_annotation.py
index c5e9f0a704..2afd54b5dd 100644
--- a/onnxscript/type_annotation.py
+++ b/onnxscript/type_annotation.py
@@ -7,11 +7,12 @@
 import collections
 import inspect
 import typing
+from typing import Any, TypeVar, Union

 import onnx
 from typing_extensions import get_args, get_origin

-from onnxscript.onnx_types import TensorType
+from onnxscript import onnx_types

 # TypeAnnotationValue represents the (value of) valid type-annotations recognized
 # by ONNX Script. TODO: Flesh out a formal definition. Currently, it supports
@@ -59,7 +60,7 @@ def _is_primitive_attr_type(typeinfo: TypeAnnotationValue) -> bool:

 def pytype_to_attrtype(
     pytype: TypeAnnotationValue,
-) -> onnx.AttributeProto.AttributeType:
+) -> typing.Optional[onnx.AttributeProto.AttributeType]:
     pytype = _remove_annotation(pytype)
     if pytype in _PYTYPE_TO_ATTRTYPE_MAP:
         return _PYTYPE_TO_ATTRTYPE_MAP[pytype]
@@ -74,13 +75,13 @@ def pytype_to_attrtype(
         elt_type = get_args(pytype)[0]
         if elt_type in _LISTTYPE_TO_ATTRTYPE_MAP:
             return _LISTTYPE_TO_ATTRTYPE_MAP[elt_type]
-    return onnx.AttributeProto.UNDEFINED
+    return None


 def _is_tensor_type(typeinfo: TypeAnnotationValue) -> bool:
-    if isinstance(typeinfo, TensorType):
+    if isinstance(typeinfo, onnx_types.TensorType):
         return True
-    if inspect.isclass(typeinfo) and issubclass(typeinfo, TensorType):
+    if inspect.isclass(typeinfo) and issubclass(typeinfo, onnx_types.TensorType):
         return True
     return False

@@ -146,3 +147,56 @@ def get_return_types(typeinfo: type | typing.Sequence[type]) -> typing.Sequence[
     if get_origin(typeinfo) is tuple:
         return get_args(typeinfo)
     return (typeinfo,)
+
+
+def _reduce_type_var_to_union(hint: typing.TypeVar):
+    """Reduce a TypeVar to a Union type on which we can use issubclass to check membership."""
+    assert isinstance(hint, TypeVar)
+
+    # If the TypeVar has a bound, use that.
+    if hint.__bound__ is not None:
+        return hint.__bound__
+
+    # If the TypeVar has no bound, use the first constraint.
+    if hint.__constraints__:
+        return Union.__getitem__(hint.__constraints__)
+
+    return Any
+
+
+def get_supported_input_types(pytype) -> list[str]:
+    """Returns a list of all supported input types for a given type annotation.
+
+    Args:
+        pytype: A type annotation.
+
+    Returns:
+        A list of all supported input types for the given type annotation.
+ """ + # TODO: Change this to + supported_types: list[str] = [] + if typing.get_origin(pytype) is Union and isinstance(typing.get_args(pytype)[0], TypeVar): + # Recursively unpack TypeVars inside an Optional + for arg in typing.get_args(pytype): + supported_types.extend(get_supported_input_types(arg)) + return supported_types + + if isinstance(pytype, TypeVar): + pytype = _reduce_type_var_to_union(pytype) + + for tensor_type in onnx_types.ALL_TENSOR_TYPES: + if pytype is None: + # The same as Any + supported_types.append(tensor_type.to_string()) + elif pytype == onnx_types.TensorType: + supported_types.append(tensor_type.to_string()) + elif isinstance(pytype, tensor_type): + supported_types.append(tensor_type.to_string()) + elif issubclass(tensor_type, pytype): + supported_types.append(tensor_type.to_string()) + # TODO(justinchuby): Handle sequence types + + return supported_types + + +def get_type_var_name From b34551c5c0eeb6991f91cba8e4b1e7a3790977cd Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 14 Apr 2023 02:24:06 +0000 Subject: [PATCH 12/22] Update on "Analyze type annotation for input types | feat(op_schema)" [ghstack-poisoned] --- onnxscript/type_annotation.py | 22 +++++++++++----------- onnxscript/type_annotation_test.py | 6 +++++- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/onnxscript/type_annotation.py b/onnxscript/type_annotation.py index 2afd54b5dd..b6034ad506 100644 --- a/onnxscript/type_annotation.py +++ b/onnxscript/type_annotation.py @@ -16,13 +16,13 @@ # TypeAnnotationValue represents the (value of) valid type-annotations recognized # by ONNX Script. TODO: Flesh out a formal definition. Currently, it supports -# * float, int, str (primitive attribute types) -# * Sequence[float], Sequence[int], Sequence[str] (attribute types) -# * Tensor types -# * Sequence[Tensor] types -# * Union of above 2 -# * TypeVars with above bounds -# * Above types with annotation attached +# - float, int, str (primitive attribute types) +# - Sequence[float], Sequence[int], Sequence[str] (attribute types) +# - Tensor types +# - Sequence[Tensor] types +# - Union of above 2 +# - TypeVars with above bounds +# - Above types with annotation attached TypeAnnotationValue = typing.Any # Map from python type to corresponding ONNX AttributeProto type @@ -164,8 +164,8 @@ def _reduce_type_var_to_union(hint: typing.TypeVar): return Any -def get_supported_input_types(pytype) -> list[str]: - """Returns a list of all supported input types for a given type annotation. +def pytype_to_input_strings(pytype: TypeAnnotationValue) -> list[str]: + """Returns a list of all supported input types in string representation for a given type annotation. Args: pytype: A type annotation. @@ -173,7 +173,6 @@ def get_supported_input_types(pytype) -> list[str]: Returns: A list of all supported input types for the given type annotation. 
""" - # TODO: Change this to supported_types: list[str] = [] if typing.get_origin(pytype) is Union and isinstance(typing.get_args(pytype)[0], TypeVar): # Recursively unpack TypeVars inside an Optional @@ -199,4 +198,5 @@ def get_supported_input_types(pytype) -> list[str]: return supported_types -def get_type_var_name +def get_type_constraint_name(pytype: TypeAnnotationValue) -> Optional[str]: + pass diff --git a/onnxscript/type_annotation_test.py b/onnxscript/type_annotation_test.py index 2d8ce5bf80..a2d5e79ebb 100644 --- a/onnxscript/type_annotation_test.py +++ b/onnxscript/type_annotation_test.py @@ -14,7 +14,7 @@ from onnxscript.tests.common import testutils -class TypeAnnotationTester(testutils.TestBase): +class TypeAnnotationTest(testutils.TestBase): def test_type_annotation(self): """Test type annotations.""" @@ -92,5 +92,9 @@ def bool_type_for_attribute(self: FLOAT[...], sorted: bool) -> FLOAT[...]: self.assertSameFunction(bool_type_for_attribute, bool_type_for_attribute_txt) +class UtilityFunctionsTest(unittest.TestCase): + def test_pytype_to_input_strings(self): + pass + if __name__ == "__main__": unittest.main() From 70146585509247b8f14a01feb9aa6ead13368be8 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 18 Apr 2023 00:28:30 +0000 Subject: [PATCH 13/22] Update base for Update on "Analyze type annotation for input types | feat(op_schema)" [ghstack-poisoned] From a80742f0d548af968e9e96fff1a076a5cab20dee Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 18 Apr 2023 16:23:06 +0000 Subject: [PATCH 14/22] Update base for Update on "Analyze type annotation for input types | feat(op_schema)" [ghstack-poisoned] From f2046cc0c16b88977222752e2b47a7dc6e8916ca Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 18 Apr 2023 19:16:16 +0000 Subject: [PATCH 15/22] Update base for Update on "Analyze type annotation for input types | feat(op_schema)" [ghstack-poisoned] From 84f70141e589edbfc75a254a16f248aba3e69ce0 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 20 Apr 2023 17:16:56 +0000 Subject: [PATCH 16/22] Update base for Update on "Analyze type annotation for input types | feat(op_schema)" [ghstack-poisoned] From 03f64d404cde3be8745e67993835a35ab15143ea Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 20 Apr 2023 17:35:56 +0000 Subject: [PATCH 17/22] Update base for Update on "Analyze type annotation for input types | feat(op_schema)" [ghstack-poisoned] From 8648dd3c1888a682eedc241d33ab66cf4d72b977 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sat, 22 Apr 2023 00:23:17 +0000 Subject: [PATCH 18/22] Update base for Update on "Analyze type annotation for input types | feat(op_schema)" [ghstack-poisoned] From 5f04381f9a40bb4bddc5093a0af66a2e02f810ea Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sat, 22 Apr 2023 00:25:18 +0000 Subject: [PATCH 19/22] Update base for Update on "Analyze type annotation for input types | feat(op_schema)" [ghstack-poisoned] From d272e202fe0156d41efa4988c0919af341973f38 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sat, 22 Apr 2023 00:30:28 +0000 Subject: [PATCH 20/22] Update base for Update on "Analyze type annotation for input types | feat(op_schema)" [ghstack-poisoned] From e61bffee11c219ef2faa04fa5ddddcdee58acdd0 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 27 Apr 2023 04:08:27 +0000 Subject: [PATCH 21/22] Update base for Update on "Analyze type annotation for input types | feat(op_schema)" [ghstack-poisoned] From 0f76ccf95ab99c59b3f9fa2826df85f8d2628dfd Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 27 Apr 2023 04:08:27 +0000 
Subject: [PATCH 22/22] Implement split_function_and_wrangler | test(torchlib)

[ghstack-poisoned]
---
 .../tests/function_libs/torch_lib/ops_test.py | 44 +++++++++----------
 1 file changed, 21 insertions(+), 23 deletions(-)

diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test.py b/onnxscript/tests/function_libs/torch_lib/ops_test.py
index 98e89d509b..140f2118cc 100644
--- a/onnxscript/tests/function_libs/torch_lib/ops_test.py
+++ b/onnxscript/tests/function_libs/torch_lib/ops_test.py
@@ -17,7 +17,7 @@
 import unittest
 import warnings
-from typing import Callable, Optional, Sequence
+from typing import Any, Callable, Optional, Sequence

 import numpy as np
 import onnx
@@ -55,14 +55,23 @@ def _should_skip_test_sample(op_name: str, sample) -> Optional[str]:
     return None


+def _split_function_and_wrangler(
+    onnx_function_and_wrangler: Callable[..., Any]
+    | tuple[Callable[..., Any], Callable[..., Any]]
+) -> tuple[Callable[..., Any], Callable[..., Any] | None]:
+    """Splits a function with an optional input wrangler into a function and an input wrangler."""
+    if isinstance(onnx_function_and_wrangler, tuple):
+        return onnx_function_and_wrangler
+
+    assert callable(onnx_function_and_wrangler)
+    return onnx_function_and_wrangler, None
+
+
 class TestFunctionValidity(unittest.TestCase):
     def test_all_script_functions_are_onnx_functions(self):
         functions = set()
         for func_with_wrangler in ops_test_data.OPINFO_FUNCTION_MAPPING_SCRIPTED.values():
-            if isinstance(func_with_wrangler, tuple):
-                func = func_with_wrangler[0]
-            else:
-                func = func_with_wrangler
+            func, _ = _split_function_and_wrangler(func_with_wrangler)
             functions.add(func)

         # TODO(justinchuby): Add from the registry
@@ -76,10 +85,7 @@ def test_all_script_functions_are_onnx_functions(self):

     def test_all_trace_only_functions_are_not_onnx_functions(self):
         for func_with_wrangler in ops_test_data.OPINFO_FUNCTION_MAPPING_TRACE_ONLY.values():
-            if isinstance(func_with_wrangler, tuple):
-                func = func_with_wrangler[0]
-            else:
-                func = func_with_wrangler
+            func, _ = _split_function_and_wrangler(func_with_wrangler)
             if isinstance(func, onnxscript.OnnxFunction):
                 raise AssertionError(
                     f"'{func.name}' is an OnnxFunction. "
" @@ -95,10 +101,7 @@ def test_all_trace_only_functions_are_not_onnx_functions(self): "Function checker is not available before ONNX 1.14", ) def test_script_function_passes_checker(self, _, func_with_wrangler): - if isinstance(func_with_wrangler, tuple): - func = func_with_wrangler[0] - else: - func = func_with_wrangler + func, _ = _split_function_and_wrangler(func_with_wrangler) function_proto = func.to_function_proto() onnx.checker.check_function(function_proto) # type: ignore[attr-defined] @@ -127,16 +130,11 @@ def run_test_output_match( ) onnx_function_and_wrangler = ops_test_data.OPINFO_FUNCTION_MAPPING[op.name] - input_wrangler = None - if isinstance(onnx_function_and_wrangler, tuple): - # Obtain the input_wrangler that manipulates the OpInfo inputs - # to match the aten operator signature - # An example is nn.functional.upsample_nearest2d, which has a different signature - # than the aten operator upsample_nearest2d - onnx_function, input_wrangler = onnx_function_and_wrangler - else: - assert callable(onnx_function_and_wrangler) - onnx_function = onnx_function_and_wrangler + # Obtain the input_wrangler that manipulates the OpInfo inputs + # to match the aten operator signature + # An example is nn.functional.upsample_nearest2d, which has a different signature + # than the aten operator upsample_nearest2d + onnx_function, input_wrangler = _split_function_and_wrangler(onnx_function_and_wrangler) for i, cpu_sample in enumerate(samples): inputs = (cpu_sample.input, *cpu_sample.args)