Auto generate OpSchema for functions #594


Closed · wants to merge 25 commits
2 changes: 1 addition & 1 deletion onnxscript/autocast.py
@@ -86,7 +86,7 @@ def cast(x, typeinfo) -> tensor.Tensor:
    return cast_inputs(get_type_info, cast, op_schema, *args)


-def static_cast_inputs(converter, op_schema: OpSchema, *args):
+def static_cast_inputs(converter, op_schema: Optional[OpSchema], *args):
    """Used for autocast during script-translation."""
    if op_schema is None:
        return args
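Presumably the annotation is loosened because onnxscript functions do not always carry an OpSchema to cast against (this PR is what starts generating them). A minimal sketch of a hypothetical call site:

```python
# Sketch (hypothetical call site): when the callee has no OpSchema,
# static_cast_inputs now passes the inputs through unchanged.
inputs = static_cast_inputs(converter, None, *inputs)
```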
1 change: 1 addition & 0 deletions onnxscript/converter.py
@@ -1326,6 +1326,7 @@ def translate_function_def(self, fn: ast.FunctionDef) -> irbuilder.IRFunction:
                self.ir_builder.add_attr_parameter(
                    self.current_fn,
                    x.arg,
+                   ta.pytype_to_attrtype(typeinfo),
                    default_value,
                )
                self.bind(x.arg, values.AttrRef(x.arg, typeinfo, self.source_of(x)))
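With this change the converter records each attribute parameter's ONNX attribute type alongside its name and default. A minimal sketch of the mapping `pytype_to_attrtype` is expected to perform — the exact mapping in this revision is assumed, not verified:

```python
import onnx

from onnxscript import type_annotation as ta

# Assumed behavior: a plain Python annotation on a script-function
# parameter maps to the corresponding onnx.AttributeProto.AttributeType.
assert ta.pytype_to_attrtype(int) == onnx.AttributeProto.INT
assert ta.pytype_to_attrtype(float) == onnx.AttributeProto.FLOAT
assert ta.pytype_to_attrtype(str) == onnx.AttributeProto.STRING
```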
2 changes: 1 addition & 1 deletion onnxscript/function_libs/torch_aten/graph_building_test.py
@@ -10,8 +10,8 @@
import onnxscript.testing
from onnxscript import FLOAT, evaluator
from onnxscript import opset17 as op
+from onnxscript._internal import version_utils
from onnxscript.function_libs.torch_aten import graph_building, ops
-from onnxscript.tests.common import version_utils


@unittest.skipIf(version_utils.torch_older_than("2.0"), "torchscript in 1.13 not supported")
38 changes: 26 additions & 12 deletions onnxscript/irbuilder.py
@@ -191,7 +191,7 @@ def __init__(self, name: str, domain: str = "") -> None:
        self.outputs: list[IRVar] = []
        self.stmts: list[IRStmt] = []
        # attribute parameters
-       self.attrs: list[str] = []
+       self.attrs: list[IRAttributeValue] = []
        # attribute parameters with default value
        self.attr_protos: list[IRAttributeValue] = []
        self.called_functions: dict[str, onnx.FunctionProto] = {}
@@ -225,8 +225,8 @@ def append_input(self, name: IRVar) -> None:
    def append_output(self, name: IRVar) -> None:
        self.outputs.append(name)

-   def add_attr_parameter(self, attr: str | IRAttributeValue) -> None:
-       if isinstance(attr, IRAttributeValue):
+   def add_attr_parameter(self, attr: IRAttributeValue, has_default: bool) -> None:
+       if has_default:
            self.attr_protos.append(attr)
        else:
            self.attrs.append(attr)
@@ -398,7 +398,7 @@ def to_function_proto(self) -> onnx.FunctionProto:
            onnx.helper.make_opsetid(domain, version) for domain, version in opsets.items()
        ]

-       # attribute_proto is introduced in version onnx==1.13.0.
+       # attribute_proto is introduced in version onnx==1.14.0.
        # If this attribute is available, onnxscript uses it to store
        # default values for attributes. The function then has two
        # lists, one list for attributes without default values,
@@ -408,9 +408,12 @@ def to_function_proto(self) -> onnx.FunctionProto:
        # list, default values are removed.
        # TODO: remove this when onnx with attribute_proto is released.
        if hasattr(onnx.FunctionProto, "attribute_proto"):
-           atts = self.attrs
+           attributes = [attr.name for attr in self.attrs]
        else:
-           atts = self.attrs + [a.attr_proto.name for a in self.attr_protos]
+           attributes = [
+               *[attr.name for attr in self.attrs],
+               *[a.attr_proto.name for a in self.attr_protos],
+           ]

        f = helper.make_function(
            self.domain,
@@ -419,11 +422,11 @@ def to_function_proto(self) -> onnx.FunctionProto:
            outputs=[y.name for y in self.outputs],
            nodes=nodes,
            opset_imports=opset_imports,  # TODO
-           attributes=atts,
+           attributes=attributes,
            doc_string=self.docstring,
        )
        if hasattr(onnx.FunctionProto, "attribute_proto"):
-           f.attribute_proto.extend([a.attr_proto for a in self.attr_protos])
+           f.attribute_proto.extend([attr.attr_proto for attr in self.attr_protos])
        return f
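For reference, a rough sketch of where the two lists land on the resulting proto, using only public onnx.helper APIs — the function and attribute names here are hypothetical:

```python
import onnx
from onnx import helper

# Hypothetical function with one attribute that has no default ("axis").
f = helper.make_function(
    "custom.domain",
    "MyFun",
    inputs=["X"],
    outputs=["Y"],
    nodes=[helper.make_node("Identity", ["X"], ["Y"])],
    opset_imports=[helper.make_opsetid("", 17)],
    attributes=["axis"],  # names only, like the list built from self.attrs
)
if hasattr(onnx.FunctionProto, "attribute_proto"):
    # Defaulted attributes travel as full AttributeProto entries.
    f.attribute_proto.append(helper.make_attribute("alpha", 1.0))
```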


@@ -463,12 +466,23 @@ def add_input(
        v = IRVar(varname, type, info)
        fn.append_input(v)

-   def add_attr_parameter(self, fn: IRFunction, varname: str, default_value) -> None:
+   def add_attr_parameter(
+       self,
+       fn: IRFunction,
+       varname: str,
+       attribute_type: onnx.AttributeProto.AttributeType,
+       default_value: Any,
+   ) -> None:
        if default_value is not None:
-           a = IRAttributeValue(helper.make_attribute(varname, default_value))
-           fn.add_attr_parameter(a)
+           fn.add_attr_parameter(
+               IRAttributeValue(helper.make_attribute(varname, default_value)),
+               has_default=True,
+           )
        else:
-           fn.add_attr_parameter(varname)
+           proto = onnx.AttributeProto()
+           proto.name = varname
+           proto.type = attribute_type
+           fn.add_attr_parameter(IRAttributeValue(proto), has_default=False)

Review comment (Contributor): Is there a better annotation we can use than Any?
Reply (Collaborator, Author): I think so. Let me refine that.
Review comment (Contributor): Do you think has_default can be an attribute of IRAttributeValue? I think that's more straightforward.
Reply (Collaborator, Author): That's a great suggestion - I thought of this too. Will make an update.

    def add_output(self, fn: IRFunction, varname: str, type, info) -> None:
        v = IRVar(varname, type, info)
87 changes: 71 additions & 16 deletions onnxscript/onnx_types.py
@@ -105,69 +105,105 @@ def to_type_proto(cls) -> onnx.TypeProto:
        shape = [cls.shape]  # example: "FLOAT[10]"
        return onnx.helper.make_tensor_type_proto(cls.dtype, shape)

+    @classmethod
+    def to_string(cls) -> str:
+        raise NotImplementedError()


class FLOAT(TensorType, dtype=onnx.TensorProto.FLOAT):
-    pass
+    @classmethod
+    def to_string(cls):
+        return "tensor(float)"


class UINT8(TensorType, dtype=onnx.TensorProto.UINT8):
-    pass
+    @classmethod
+    def to_string(cls):
+        return "tensor(uint8)"


class INT8(TensorType, dtype=onnx.TensorProto.INT8):
-    pass
+    @classmethod
+    def to_string(cls):
+        return "tensor(int8)"


class UINT16(TensorType, dtype=onnx.TensorProto.UINT16):
-    pass
+    @classmethod
+    def to_string(cls):
+        return "tensor(uint16)"


class INT16(TensorType, dtype=onnx.TensorProto.INT16):
-    pass
+    @classmethod
+    def to_string(cls):
+        return "tensor(int16)"


class INT32(TensorType, dtype=onnx.TensorProto.INT32):
-    pass
+    @classmethod
+    def to_string(cls):
+        return "tensor(int32)"


class INT64(TensorType, dtype=onnx.TensorProto.INT64):
-    pass
+    @classmethod
+    def to_string(cls):
+        return "tensor(int64)"


class STRING(TensorType, dtype=onnx.TensorProto.STRING):
-    pass
+    @classmethod
+    def to_string(cls):
+        return "tensor(string)"


class BOOL(TensorType, dtype=onnx.TensorProto.BOOL):
-    pass
+    @classmethod
+    def to_string(cls):
+        return "tensor(bool)"


class FLOAT16(TensorType, dtype=onnx.TensorProto.FLOAT16):
-    pass
+    @classmethod
+    def to_string(cls):
+        return "tensor(float16)"


class DOUBLE(TensorType, dtype=onnx.TensorProto.DOUBLE):
-    pass
+    @classmethod
+    def to_string(cls):
+        return "tensor(double)"


class UINT32(TensorType, dtype=onnx.TensorProto.UINT32):
-    pass
+    @classmethod
+    def to_string(cls):
+        return "tensor(uint32)"


class UINT64(TensorType, dtype=onnx.TensorProto.UINT64):
-    pass
+    @classmethod
+    def to_string(cls):
+        return "tensor(uint64)"


class COMPLEX64(TensorType, dtype=onnx.TensorProto.COMPLEX64):
-    pass
+    @classmethod
+    def to_string(cls):
+        return "tensor(complex64)"


class COMPLEX128(TensorType, dtype=onnx.TensorProto.COMPLEX128):
-    pass
+    @classmethod
+    def to_string(cls):
+        return "tensor(complex128)"


class BFLOAT16(TensorType, dtype=onnx.TensorProto.BFLOAT16):
-    pass
+    @classmethod
+    def to_string(cls):
+        return "tensor(bfloat16)"


@@ -203,3 +239,22 @@ def onnx_type_to_onnxscript_repr(onnx_type: onnx.TypeProto) -> str:

# Currently, only tensor types are supported. Need to expand support for other ONNX types.
ONNXType = TensorType
Review comment (Contributor): Do we still need this?
Reply (Collaborator, Author): Not sure. Maybe?
+ALL_TENSOR_TYPES = (
+    FLOAT,
+    UINT8,
+    INT8,
+    UINT16,
+    INT16,
+    INT32,
+    INT64,
+    STRING,
+    BOOL,
+    FLOAT16,
+    DOUBLE,
+    UINT32,
+    UINT64,
+    COMPLEX64,
+    COMPLEX128,
+    BFLOAT16,
+)

Review comment (Contributor): Could you please sort them in alphabetical order for better readability?
Reply (Collaborator, Author): Done
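A sketch of how schema generation can use this tuple: enumerating every allowed type string for an otherwise unconstrained input.

```python
from onnxscript.onnx_types import ALL_TENSOR_TYPES

# One entry per tensor type, in the order declared above.
type_strings = [t.to_string() for t in ALL_TENSOR_TYPES]
# ["tensor(float)", "tensor(uint8)", ..., "tensor(bfloat16)"]
```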
@@ -58,11 +58,11 @@

import onnxscript
import onnxscript.evaluator
+from onnxscript._internal import version_utils
from onnxscript.function_libs.torch_aten import graph_building
from onnxscript.function_libs.torch_aten.ops import core as core_ops
from onnxscript.function_libs.torch_aten.ops import nn as nn_ops
from onnxscript.function_libs.torch_aten.ops import special as special_ops
-from onnxscript.tests.common import version_utils
from onnxscript.tests.function_libs.torch_aten import extra_opinfo

T = TypeVar("T")
3 changes: 2 additions & 1 deletion onnxscript/tests/functions/onnxfns1A_test.py
@@ -3,7 +3,8 @@
import onnx
import pytest

-from onnxscript.tests.common import onnx_script_test_case, version_utils
+from onnxscript._internal import version_utils
+from onnxscript.tests.common import onnx_script_test_case
from onnxscript.tests.models import onnxfns1A


3 changes: 2 additions & 1 deletion onnxscript/tests/functions/onnxfns2_test.py
@@ -3,7 +3,8 @@
import onnxruntime
import pytest

-from onnxscript.tests.common import onnx_script_test_case, version_utils
+from onnxscript._internal import version_utils
+from onnxscript.tests.common import onnx_script_test_case
from onnxscript.tests.models import onnxfns2


3 changes: 2 additions & 1 deletion onnxscript/tests/functions/onnxfns_test.py
@@ -8,7 +8,8 @@
import onnx
import pytest

-from onnxscript.tests.common import onnx_script_test_case, version_utils
+from onnxscript._internal import version_utils
+from onnxscript.tests.common import onnx_script_test_case
from onnxscript.tests.models import onnxfns1


56 changes: 53 additions & 3 deletions onnxscript/type_annotation.py
@@ -7,11 +7,12 @@
import collections
import inspect
import typing
+from typing import Any, TypeVar, Union

import onnx
from typing_extensions import get_args, get_origin

-from onnxscript.onnx_types import TensorType
+from onnxscript import onnx_types

# TypeAnnotationValue represents the (value of) valid type-annotations recognized
# by ONNX Script. TODO: Flesh out a formal definition. Currently, it supports
@@ -78,9 +79,9 @@ def pytype_to_attrtype(


def _is_tensor_type(typeinfo: TypeAnnotationValue) -> bool:
-   if isinstance(typeinfo, TensorType):
+   if isinstance(typeinfo, onnx_types.TensorType):
        return True
-   if inspect.isclass(typeinfo) and issubclass(typeinfo, TensorType):
+   if inspect.isclass(typeinfo) and issubclass(typeinfo, onnx_types.TensorType):
        return True
    return False

@@ -146,3 +147,52 @@ def get_return_types(typeinfo: type | typing.Sequence[type]) -> typing.Sequence[
    if get_origin(typeinfo) is tuple:
        return get_args(typeinfo)
    return (typeinfo,)


+def _reduce_type_var_to_union(hint: typing.TypeVar):
+    """Reduce a TypeVar to a Union type on which we can use issubclass to check membership."""
+    assert isinstance(hint, TypeVar)
+
+    # If the TypeVar has a bound, use that.
+    if hint.__bound__ is not None:
+        return hint.__bound__
+
+    # If the TypeVar is constrained, reduce the constraints to a Union.
+    if hint.__constraints__:
+        return Union.__getitem__(hint.__constraints__)
+
+    return Any
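An illustration of the two reduction paths, inside type_annotation.py — the TypeVar names here are hypothetical:

```python
from typing import TypeVar, Union

from onnxscript import FLOAT, FLOAT16, INT32, INT64

TFloat = TypeVar("TFloat", bound=Union[FLOAT, FLOAT16])  # bounded
TInt = TypeVar("TInt", INT32, INT64)  # constrained

assert _reduce_type_var_to_union(TFloat) == Union[FLOAT, FLOAT16]  # bound wins
assert _reduce_type_var_to_union(TInt) == Union[INT32, INT64]  # constraints
```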


+def get_supported_input_types(pytype) -> list[str]:
+    """Returns a list of all supported input types for a given type annotation.
+
+    Args:
+        pytype: A type annotation.
+
+    Returns:
+        A list of all supported input types for the given type annotation.
+    """
+    supported_types: list[str] = []
+    if typing.get_origin(pytype) is Union and isinstance(typing.get_args(pytype)[0], TypeVar):
+        # Recursively unpack TypeVars inside an Optional
+        for arg in typing.get_args(pytype):
+            supported_types.extend(get_supported_input_types(arg))
+        return supported_types
+
+    if isinstance(pytype, TypeVar):
+        pytype = _reduce_type_var_to_union(pytype)
+
+    for tensor_type in onnx_types.ALL_TENSOR_TYPES:
+        if pytype is None:
+            # The same as Any
+            supported_types.append(tensor_type.to_string())
+        elif pytype == onnx_types.TensorType:
+            supported_types.append(tensor_type.to_string())
+        elif isinstance(pytype, tensor_type):
+            supported_types.append(tensor_type.to_string())
+        elif issubclass(tensor_type, pytype):
+            supported_types.append(tensor_type.to_string())
+        # TODO(justinchuby): Handle sequence types
+
+    return supported_types
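Expected behavior, read directly off the loop above:

```python
from onnxscript import FLOAT
from onnxscript.onnx_types import ALL_TENSOR_TYPES, TensorType

# A concrete tensor type constrains the input to itself.
assert get_supported_input_types(FLOAT) == ["tensor(float)"]
# The bare TensorType annotation admits every tensor type.
assert len(get_supported_input_types(TensorType)) == len(ALL_TENSOR_TYPES)
```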