Auto generate OpSchema for functions | feat(op_schema) #626
Changes from all commits
@@ -8,12 +8,14 @@
 import logging
 import types
 from enum import IntFlag
-from typing import Any, Optional, Sequence, _GenericAlias  # type: ignore[attr-defined]
+from typing import _GenericAlias  # type: ignore[attr-defined]
+from typing import Any, Optional, Sequence

 import onnx
 import onnx.defs

-from onnxscript import irbuilder, sourceinfo
+from onnxscript import irbuilder, sourceinfo, type_annotation
+from onnxscript._internal import version_utils

 _ATTRIBUTE_TYPE_TO_PYTHON_TYPE = {
     onnx.defs.OpSchema.AttrType.FLOAT: float,
@@ -34,6 +36,7 @@

 # A special value to indicate that the default value is not specified
 _EmptyDefault = object()
+_ONNX_OP_SCHEMA_WRITABLE = not version_utils.onnx_older_than("1.14")


 class Opset:
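Not part of the diff: `version_utils` is an onnxscript-internal helper, so as a hedged sketch of what the gate above checks (`onnx.defs.OpSchema` only became constructible from Python in onnx 1.14), an equivalent standalone version test might look like this, assuming the common `packaging` package:

```python
# Hedged sketch (not the onnxscript implementation): reproduce the
# "onnx >= 1.14" gate without onnxscript's internal version_utils.
# `packaging` is an assumed third-party dependency here.
import onnx
from packaging import version

_ONNX_OP_SCHEMA_WRITABLE = version.parse(onnx.__version__) >= version.parse("1.14")
```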
@@ -173,7 +176,7 @@ def __init__(
     ) -> None:
         self.opset = opset
         self.opname = opname
-        self.opschema = opschema
+        self._opschema = opschema
         self._param_schemas: Optional[tuple[ParamSchema, ...]] = None

     def __call__(self, *args, **kwargs):
@@ -190,9 +193,13 @@ def __call__(self, *args, **kwargs):
     def is_single_op(self) -> bool:
         return isinstance(self.opname, str)

+    @property
+    def opschema(self) -> Optional[onnx.defs.OpSchema]:
+        return self._opschema
+
     def get_schema(self) -> Optional[onnx.defs.OpSchema]:
         """Returns the ONNX OpSchema for this op."""
-        if self.opschema:
+        if self.opschema is not None:
             return self.opschema
         return self.opset[self.opname]

@@ -249,6 +256,100 @@ class OnnxClosure:
     function: Any


+@dataclasses.dataclass
+class TypeConstraint:
+    """Represents a type constraint for an ONNX op.
+
+    Attributes:
+        name: The name of the type constraint.
+        allowed_types: The allowed types for the type constraint.
+    """
+
+    name: str
+    allowed_types: list[str]
+    description: str = ""
+
+    def as_tuple(self) -> tuple[str, list[str], str]:
+        """Returns the type constraint as a tuple."""
+        return (self.name, self.allowed_types, self.description)
+
+
+def op_schema_from_function_ir(
+    function_ir: irbuilder.IRFunction, opset: Opset
+) -> onnx.defs.OpSchema:
+    """Construct an ONNX OpSchema from an IRFunction."""
+
+    # Find all distinct types in the inputs and outputs
+    distinct_types = {arg.typeinfo for arg in function_ir.inputs}.union(
+        {arg.typeinfo for arg in function_ir.outputs}
+    )
+    # Create a mapping from type to a unique name
+    type_to_constraint = {}
+    for i, type_ in enumerate(distinct_types):
+        name = f"T{i}"
+        type_to_constraint[type_] = TypeConstraint(
+            name=type_annotation.get_type_constraint_name(type_) or name,
+            allowed_types=type_annotation.pytype_to_type_strings(type_),
+        )
+
+    formal_inputs = [
+        onnx.defs.OpSchema.FormalParameter(
+            arg.name,
+            type_to_constraint[arg.typeinfo].name,
+            param_option=(
+                onnx.defs.OpSchema.FormalParameterOption.Optional
+                if type_annotation.is_optional(arg.typeinfo)
+                else onnx.defs.OpSchema.FormalParameterOption.Single
+            ),
+            # TODO(justinchu): Check this is_homogeneous thing
+            is_homogeneous=True,
+        )
+        for arg in function_ir.inputs
+    ]
+    formal_outputs = [
+        onnx.defs.OpSchema.FormalParameter(
+            arg.name,
+            type_to_constraint[arg.typeinfo].name,
+            param_option=(
+                onnx.defs.OpSchema.FormalParameterOption.Optional
+                if type_annotation.is_optional(arg.typeinfo)
+                else onnx.defs.OpSchema.FormalParameterOption.Single
+            ),
+            # TODO(justinchu): Check this is_homogeneous thing
+            is_homogeneous=True,
+        )
+        for arg in function_ir.outputs
+    ]
+
+    return onnx.defs.OpSchema(
+        function_ir.name,
+        opset.domain,
+        since_version=opset.version,
+        doc=function_ir.docstring,
+        inputs=formal_inputs,
+        outputs=formal_outputs,
+        type_constraints=[constraint.as_tuple() for constraint in type_to_constraint.values()],
+        attributes=[
+            *[
+                onnx.defs.OpSchema.Attribute(
+                    attr.name,
+                    type=onnx.defs.OpSchema.AttrType(attr.type),
+                )
+                for attr in function_ir.attrs
+                if not attr.has_default
+            ],
+            *[
+                onnx.defs.OpSchema.Attribute(
+                    attr.name,
+                    default_value=attr.attr_proto,
+                )
+                for attr in function_ir.attrs
+                if attr.has_default
+            ],
+        ],
+    )
+
+
 class OnnxFunction(Op):
     """Represents an ONNX op for which a function-body has been defined in onnxscript.

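For context on what `op_schema_from_function_ir` passes to the `onnx.defs.OpSchema` constructor, here is a hedged, self-contained sketch of the same call with literal values. The op name, domain, and types are illustrative only, and this requires onnx >= 1.14:

```python
# Hedged sketch: build an OpSchema by hand, mirroring the tuple form that
# TypeConstraint.as_tuple() produces: (name, allowed_types, description).
import onnx.defs

schema = onnx.defs.OpSchema(
    "AddOne",      # illustrative op name
    "my.domain",   # illustrative domain
    since_version=1,
    doc="Adds one to the input.",
    inputs=[onnx.defs.OpSchema.FormalParameter("X", "T0")],
    outputs=[onnx.defs.OpSchema.FormalParameter("Y", "T0")],
    type_constraints=[("T0", ["tensor(float)", "tensor(double)"], "")],
)
print(schema.name, schema.since_version)
```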
@@ -276,12 +377,26 @@ def __init__(
         self.source = source
         self.kwargs = kwargs
         self._param_schemas: Optional[tuple[ParamSchema, ...]] = None
+        self._opschema: Optional[onnx.defs.OpSchema] = None

     @property
     def name(self):
         """Returns the function name."""
         return self.opname

+    @property
+    def opschema(self) -> Optional[onnx.defs.OpSchema]:
+        """Construct an OpSchema from function_ir."""
+        if self._opschema is not None:
+            return self._opschema
+
+        if not _ONNX_OP_SCHEMA_WRITABLE:
+            return None
+
+        self._opschema = op_schema_from_function_ir(self.function_ir, self.opset)
+
+        return self._opschema
+
     def __getitem__(self, instance):
         """Returns a lambda to evaluate function using given evaluator instance.

Review thread on the new opschema property:
- Is this being used anywhere yet?
- No. It will be used by the exporter. For now I added tests to make sure all torch_lib functions have opschemas defined.
- Sorry, the tests were left out for some reason. I added them back.
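A hedged usage sketch of the new property. The `AddOne` function is made up for the example, and it assumes onnx >= 1.14 so the schema is writable:

```python
# Hedged sketch: after this PR, an OnnxFunction builds its OpSchema lazily
# from function_ir and caches it; on onnx < 1.14 the property returns None.
from onnxscript import FLOAT, script
from onnxscript import opset18 as op

@script()
def AddOne(X: FLOAT[...]) -> FLOAT[...]:
    return op.Add(X, op.Constant(value_float=1.0))

schema = AddOne.opschema
if schema is not None:
    print(schema.name, [p.name for p in schema.inputs])
```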
@@ -311,6 +426,9 @@ def param_schemas(self) -> tuple[ParamSchema, ...]:
         if self._param_schemas is not None:
             return self._param_schemas

+        # NOTE: We generate the parameter schemas from the function_ir instead
+        # of relying on the auto generated OpSchema because we need to preserve the keyword
+        # argument order from the Python function definition, which is lost in OpSchema.
         function_ir = self.function_ir
         # The first len(func_ir.inputs) arguments are onnx inputs
         inputs = function_ir.inputs
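A hedged illustration of the ordering concern in the NOTE above. The function and parameter names are invented for the example:

```python
# Hedged sketch: param_schemas() should follow the Python def's declaration
# order (X, then axis, then keepdims), which an OpSchema's attribute map
# would not guarantee on its own.
from onnxscript import FLOAT, INT64, script
from onnxscript import opset18 as op

@script()
def ArgMaxKeep(X: FLOAT[...], axis: int = 0, keepdims: int = 1) -> INT64[...]:
    return op.ArgMax(X, axis=axis, keepdims=keepdims)

print([s.name for s in ArgMaxKeep.param_schemas()])  # expected: ['X', 'axis', 'keepdims']
```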
Check notice (Code scanning / CodeQL): Cyclic import