Auto generate OpSchema for functions | feat(op_schema) #626

Merged: 72 commits, Apr 28, 2023
Commits
c6505cf
Move version_utils to _internal | chore
justinchuby Apr 12, 2023
268ad8f
Merge `attrs` and `attr_protos` in `IRFunction` | chore(irbuilder)
justinchuby Apr 12, 2023
327a7d6
Auto generate OpSchema for functions | feat
justinchuby Apr 12, 2023
95c4ba6
Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 12, 2023
0883d6f
Update base for Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 12, 2023
a353d73
Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 12, 2023
962c13b
Update base for Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 12, 2023
fa16ca5
Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 12, 2023
821821c
Update base for Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 12, 2023
46fa00f
Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 13, 2023
6b9106b
Update base for Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 13, 2023
43345e2
Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 13, 2023
bbe8e7e
Update base for Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 13, 2023
f835d9b
Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 13, 2023
26d5caa
Update base for Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 13, 2023
e953774
Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 13, 2023
b3a035d
Update base for Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 13, 2023
7fda2d1
Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 13, 2023
03cc7f4
Update base for Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 13, 2023
b760abc
Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 13, 2023
7321b22
Update base for Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 13, 2023
606be97
Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 14, 2023
efc7708
Update base for Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 14, 2023
a3f9b50
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 14, 2023
67a8ee0
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 14, 2023
08a27fa
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 14, 2023
1d4a0b4
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 14, 2023
28b4a48
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 17, 2023
2d158ac
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 17, 2023
19f9484
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 18, 2023
622b688
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 18, 2023
b61543a
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 18, 2023
31bb69c
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 18, 2023
6cfa67c
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 18, 2023
2c6be92
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 18, 2023
0502482
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 20, 2023
a9a0845
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 20, 2023
ad57790
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 20, 2023
79d3605
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 20, 2023
f2455dc
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 20, 2023
1439aaf
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 20, 2023
a2fde87
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 21, 2023
ea41c8f
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 21, 2023
b334880
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 22, 2023
138c2ed
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 22, 2023
cef03af
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 22, 2023
90208e1
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 22, 2023
ed79fce
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 22, 2023
2d9627e
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 22, 2023
6376c93
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 22, 2023
8f5f7ba
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 22, 2023
dd80bff
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 25, 2023
49d8d0e
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 25, 2023
14d2149
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 27, 2023
b3dbb7f
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 27, 2023
14a61f7
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 27, 2023
a431c7a
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 27, 2023
556cb95
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 27, 2023
3e10d4b
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 27, 2023
b6a5df0
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 27, 2023
d36548e
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 27, 2023
b34b2bd
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 28, 2023
e1782a7
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 28, 2023
246aac4
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 28, 2023
05d7b9e
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 28, 2023
e2c22ab
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 28, 2023
9e0ff7f
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 28, 2023
147e34d
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 28, 2023
5a4c9f1
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 28, 2023
c7cf1e8
Merge branch 'main' into gh/justinchuby/16/head
justinchuby Apr 28, 2023
1159e42
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 28, 2023
f5971d2
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 28, 2023
2 changes: 1 addition & 1 deletion onnxscript/autocast.py
@@ -86,7 +86,7 @@ def cast(x, typeinfo) -> tensor.Tensor:
return cast_inputs(get_type_info, cast, op_schema, *args)


def static_cast_inputs(converter, op_schema: OpSchema, *args):
def static_cast_inputs(converter, op_schema: Optional[OpSchema], *args):
"""Used for autocast during script-translation."""
if op_schema is None:
return args
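The hunk above widens `op_schema` to `Optional[OpSchema]` and returns the inputs unchanged when no schema is available. A minimal standalone sketch of that guard (`static_cast_inputs_sketch` is a hypothetical stand-in; the real function depends on onnxscript internals and a converter):

```python
from typing import Optional, Tuple


def static_cast_inputs_sketch(op_schema: Optional[object], *args) -> Tuple:
    """Simplified model of the None-guard added in this diff.

    When no OpSchema is available, inputs pass through unchanged;
    otherwise each input would be cast per the schema's expected types
    (represented here by a placeholder transformation).
    """
    if op_schema is None:
        return args
    return tuple(("cast", arg) for arg in args)
```

With `op_schema=None` the arguments come back untouched, which is exactly the new early-return behavior.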
13 changes: 13 additions & 0 deletions onnxscript/tests/function_libs/torch_lib/ops_test.py
@@ -106,6 +106,19 @@ def test_script_function_passes_checker(self, _, func_with_wrangler):
function_proto = func.to_function_proto()
onnx.checker.check_function(function_proto) # type: ignore[attr-defined]

@parameterized.parameterized.expand(
list(ops_test_data.OPINFO_FUNCTION_MAPPING_SCRIPTED.items())
)
@unittest.skipIf(
version_utils.onnx_older_than("1.15"),
"OpSchema is not writable before ONNX 1.15",
)
def test_script_function_has_op_schema(self, _, func_with_wrangler):
func, _ = _split_function_and_wrangler(func_with_wrangler)
schema = func.opschema
self.assertIsNotNone(schema)
self.assertEqual(schema.name, func.name)


def run_test_output_match(
test_suite: unittest.TestCase,
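The new test is version-gated because `OpSchema` only became constructible from Python in newer ONNX releases. A hedged sketch of the kind of comparison `version_utils.onnx_older_than` performs (the helper name is real in this repo, but this body is an illustrative stand-in for plain dotted versions, not the actual implementation):

```python
def onnx_older_than(min_version: str, current_version: str = "1.14.0") -> bool:
    """Illustrative stand-in: compare dotted version strings numerically.

    Real-world version strings (e.g. release candidates) need more care;
    this sketch only handles purely numeric components.
    """

    def parse(version: str):
        return tuple(int(part) for part in version.split("."))

    return parse(current_version) < parse(min_version)
```

Under an assumed current version of 1.14.0, checking against "1.15" reports the runtime as too old, so the test would be skipped.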
126 changes: 122 additions & 4 deletions onnxscript/values.py
@@ -8,12 +8,14 @@
import logging
import types
from enum import IntFlag
from typing import Any, Optional, Sequence, _GenericAlias # type: ignore[attr-defined]
from typing import _GenericAlias # type: ignore[attr-defined]
from typing import Any, Optional, Sequence

import onnx
import onnx.defs

from onnxscript import irbuilder, sourceinfo
from onnxscript import irbuilder, sourceinfo, type_annotation

CodeQL notice (cyclic import): import of module onnxscript.irbuilder begins an import cycle.
from onnxscript._internal import version_utils

_ATTRIBUTE_TYPE_TO_PYTHON_TYPE = {
onnx.defs.OpSchema.AttrType.FLOAT: float,
@@ -34,6 +36,7 @@

# A special value to indicate that the default value is not specified
_EmptyDefault = object()
_ONNX_OP_SCHEMA_WRITABLE = not version_utils.onnx_older_than("1.14")


class Opset:
@@ -173,7 +176,7 @@ def __init__(
) -> None:
self.opset = opset
self.opname = opname
self.opschema = opschema
self._opschema = opschema
self._param_schemas: Optional[tuple[ParamSchema, ...]] = None

def __call__(self, *args, **kwargs):
@@ -190,9 +193,13 @@ def __call__(self, *args, **kwargs):
def is_single_op(self) -> bool:
return isinstance(self.opname, str)

@property
def opschema(self) -> Optional[onnx.defs.OpSchema]:
return self._opschema

def get_schema(self) -> Optional[onnx.defs.OpSchema]:
"""Returns the ONNX OpSchema for this op."""
if self.opschema:
if self.opschema is not None:
return self.opschema
return self.opset[self.opname]

@@ -249,6 +256,100 @@ class OnnxClosure:
function: Any


@dataclasses.dataclass
class TypeConstraint:
"""Represents a type constraint for an ONNX op.

Attributes:
name: The name of the type constraint.
allowed_types: The allowed types for the type constraint.
"""

name: str
allowed_types: list[str]
description: str = ""

def as_tuple(self) -> tuple[str, list[str], str]:
"""Returns the type constraint as a tuple."""
return (self.name, self.allowed_types, self.description)


def op_schema_from_function_ir(
function_ir: irbuilder.IRFunction, opset: Opset
) -> onnx.defs.OpSchema:
"""Construct an ONNX OpSchema from an IRFunction."""

# Find all distinct types in the inputs and outputs
distinct_types = {arg.typeinfo for arg in function_ir.inputs}.union(
{arg.typeinfo for arg in function_ir.outputs}
)
# Create a mapping from type to a unique name
type_to_constraint = {}
for i, type_ in enumerate(distinct_types):
name = f"T{i}"
type_to_constraint[type_] = TypeConstraint(
name=type_annotation.get_type_constraint_name(type_) or name,
allowed_types=type_annotation.pytype_to_type_strings(type_),
)

formal_inputs = [
onnx.defs.OpSchema.FormalParameter(
arg.name,
type_to_constraint[arg.typeinfo].name,
param_option=(
onnx.defs.OpSchema.FormalParameterOption.Optional
if type_annotation.is_optional(arg.typeinfo)
else onnx.defs.OpSchema.FormalParameterOption.Single
),
# TODO(justinchu): Check this is_homogeneous thing
is_homogeneous=True,
)
for arg in function_ir.inputs
]
formal_outputs = [
onnx.defs.OpSchema.FormalParameter(
arg.name,
type_to_constraint[arg.typeinfo].name,
param_option=(
onnx.defs.OpSchema.FormalParameterOption.Optional
if type_annotation.is_optional(arg.typeinfo)
else onnx.defs.OpSchema.FormalParameterOption.Single
),
# TODO(justinchu): Check this is_homogeneous thing
is_homogeneous=True,
)
for arg in function_ir.outputs
]

return onnx.defs.OpSchema(
function_ir.name,
opset.domain,
since_version=opset.version,
doc=function_ir.docstring,
inputs=formal_inputs,
outputs=formal_outputs,
type_constraints=[constraint.as_tuple() for constraint in type_to_constraint.values()],
attributes=[
*[
onnx.defs.OpSchema.Attribute(
attr.name,
type=onnx.defs.OpSchema.AttrType(attr.type),
)
for attr in function_ir.attrs
if not attr.has_default
],
*[
onnx.defs.OpSchema.Attribute(
attr.name,
default_value=attr.attr_proto,
)
for attr in function_ir.attrs
if attr.has_default
],
],
)
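`op_schema_from_function_ir` keys type constraints on distinct input/output annotations, so every parameter sharing an annotation shares one `T{i}` constraint. A simplified, self-contained model of that grouping (the real code also consults `type_annotation.get_type_constraint_name` and iterates a set, so the exact `T{i}` numbering is not guaranteed there):

```python
from typing import Dict, Sequence


def assign_constraint_names(typeinfos: Sequence[str]) -> Dict[str, str]:
    """Map each distinct type annotation to a shared constraint name.

    Parameters with the same annotation reuse the same "T{i}" name,
    which is what ties an op's inputs and outputs to one constraint.
    """
    type_to_name: Dict[str, str] = {}
    for type_ in typeinfos:
        if type_ not in type_to_name:
            type_to_name[type_] = f"T{len(type_to_name)}"
    return type_to_name
```

Two parameters annotated `"FLOAT"` would both resolve to `T0`, while a distinct `"INT64"` annotation gets its own `T1`.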


class OnnxFunction(Op):
"""Represents an ONNX op for which a function-body has been defined in onnxscript.

@@ -276,12 +377,26 @@ def __init__(
self.source = source
self.kwargs = kwargs
self._param_schemas: Optional[tuple[ParamSchema, ...]] = None
self._opschema: Optional[onnx.defs.OpSchema] = None

@property
def name(self):
"""Returns the function name."""
return self.opname

@property
def opschema(self) -> Optional[onnx.defs.OpSchema]:
Review thread on this property:

Reviewer: Is this being used anywhere yet?
Author (justinchuby): No. It will be used by the exporter. For now I added tests to make sure all torch_lib functions have opschemas defined.
Author (justinchuby): Sorry, the tests were left out for some reason. I added them back.
"""Construct an OpSchema from function_ir."""
if self._opschema is not None:
return self._opschema

if not _ONNX_OP_SCHEMA_WRITABLE:
return None

self._opschema = op_schema_from_function_ir(self.function_ir, self.opset)

return self._opschema
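The property above is a lazy cache: build the schema once on first access, and return `None` when the installed ONNX cannot construct writable schemas. A standalone sketch of that pattern (`FakeSchemaHolder` and its fields are hypothetical; a dict stands in for the OpSchema object):

```python
class FakeSchemaHolder:
    """Sketch of the build-once/cache pattern used by OnnxFunction.opschema."""

    def __init__(self, schema_writable: bool = True):
        self._opschema = None
        self._schema_writable = schema_writable
        self.build_count = 0  # Tracks how many times a schema was constructed.

    @property
    def opschema(self):
        if self._opschema is not None:
            return self._opschema  # Cached: no rebuild on later accesses.
        if not self._schema_writable:
            return None  # Older runtimes: schema generation unavailable.
        self.build_count += 1
        self._opschema = {"name": "example"}  # Stand-in for schema construction.
        return self._opschema
```

Repeated accesses return the same object without rebuilding, and the non-writable case degrades to `None` rather than raising.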

def __getitem__(self, instance):
"""Returns a lambda to evaluate function using given evaluator instance.

Expand Down Expand Up @@ -311,6 +426,9 @@ def param_schemas(self) -> tuple[ParamSchema, ...]:
if self._param_schemas is not None:
return self._param_schemas

# NOTE: We generate the parameter schemas from the function_ir instead
# of relying on the auto generated OpSchema because we need to preserve the keyword
# argument order from the Python function definition, which is lost in OpSchema.
function_ir = self.function_ir
# The first len(func_ir.inputs) arguments are onnx inputs
inputs = function_ir.inputs
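The NOTE in the final hunk explains why parameter schemas still come from `function_ir` rather than the generated OpSchema: a Python function definition carries a declaration order that name-keyed schema attributes lose. A small stdlib demonstration of that ordering guarantee (`example_op` is a hypothetical function, not one from this repo):

```python
import inspect


def example_op(x, y, *, alpha: float = 1.0, beta: float = 2.0):
    """Hypothetical op: two positional inputs, then keyword-only attributes."""
    return x


# inspect.signature preserves the exact declaration order, including
# keyword-only parameters -- the ordering param_schemas wants to keep.
parameter_order = list(inspect.signature(example_op).parameters)
```

The list comes back in source order (`x`, `y`, `alpha`, `beta`), which is the information an unordered, name-keyed attribute collection would discard.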