Skip to content

Auto generate OpSchema for functions | feat(op_schema) #626

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 72 commits into from
Apr 28, 2023
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
72 commits
Select commit Hold shift + click to select a range
c6505cf
Move version_utils to _internal | chore
justinchuby Apr 12, 2023
268ad8f
Merge `attrs` and `attr_protos` in `IRFunction` | chore(irbuilder)
justinchuby Apr 12, 2023
327a7d6
Auto generate OpSchema for functions | feat
justinchuby Apr 12, 2023
95c4ba6
Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 12, 2023
0883d6f
Update base for Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 12, 2023
a353d73
Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 12, 2023
962c13b
Update base for Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 12, 2023
fa16ca5
Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 12, 2023
821821c
Update base for Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 12, 2023
46fa00f
Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 13, 2023
6b9106b
Update base for Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 13, 2023
43345e2
Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 13, 2023
bbe8e7e
Update base for Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 13, 2023
f835d9b
Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 13, 2023
26d5caa
Update base for Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 13, 2023
e953774
Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 13, 2023
b3a035d
Update base for Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 13, 2023
7fda2d1
Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 13, 2023
03cc7f4
Update base for Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 13, 2023
b760abc
Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 13, 2023
7321b22
Update base for Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 13, 2023
606be97
Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 14, 2023
efc7708
Update base for Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 14, 2023
a3f9b50
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 14, 2023
67a8ee0
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 14, 2023
08a27fa
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 14, 2023
1d4a0b4
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 14, 2023
28b4a48
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 17, 2023
2d158ac
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 17, 2023
19f9484
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 18, 2023
622b688
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 18, 2023
b61543a
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 18, 2023
31bb69c
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 18, 2023
6cfa67c
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 18, 2023
2c6be92
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 18, 2023
0502482
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 20, 2023
a9a0845
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 20, 2023
ad57790
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 20, 2023
79d3605
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 20, 2023
f2455dc
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 20, 2023
1439aaf
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 20, 2023
a2fde87
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 21, 2023
ea41c8f
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 21, 2023
b334880
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 22, 2023
138c2ed
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 22, 2023
cef03af
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 22, 2023
90208e1
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 22, 2023
ed79fce
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 22, 2023
2d9627e
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 22, 2023
6376c93
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 22, 2023
8f5f7ba
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 22, 2023
dd80bff
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 25, 2023
49d8d0e
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 25, 2023
14d2149
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 27, 2023
b3dbb7f
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 27, 2023
14a61f7
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 27, 2023
a431c7a
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 27, 2023
556cb95
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 27, 2023
3e10d4b
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 27, 2023
b6a5df0
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 27, 2023
d36548e
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 27, 2023
b34b2bd
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 28, 2023
e1782a7
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 28, 2023
246aac4
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 28, 2023
05d7b9e
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 28, 2023
e2c22ab
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 28, 2023
9e0ff7f
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 28, 2023
147e34d
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 28, 2023
5a4c9f1
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 28, 2023
c7cf1e8
Merge branch 'main' into gh/justinchuby/16/head
justinchuby Apr 28, 2023
1159e42
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 28, 2023
f5971d2
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 28, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion onnxscript/autocast.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def cast(x, typeinfo) -> tensor.Tensor:
return cast_inputs(get_type_info, cast, op_schema, *args)


def static_cast_inputs(converter, op_schema: OpSchema, *args):
def static_cast_inputs(converter, op_schema: Optional[OpSchema], *args):
"""Used for autocast during script-translation."""
if op_schema is None:
return args
Expand Down
1 change: 1 addition & 0 deletions onnxscript/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1326,6 +1326,7 @@ def translate_function_def(self, fn: ast.FunctionDef) -> irbuilder.IRFunction:
self.ir_builder.add_attr_parameter(
self.current_fn,
x.arg,
ta.pytype_to_attrtype(typeinfo),
default_value,
)
self.bind(x.arg, values.AttrRef(x.arg, typeinfo, self.source_of(x)))
Expand Down
6 changes: 2 additions & 4 deletions onnxscript/converter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def test_renaming(self):

self.validate_save(renaming, shape_inference=False)

@unittest.skipIf(True, reason="TypeError: val must be numeric not <class 'NoneType'>")
@unittest.skip(reason="TypeError: val must be numeric not <class 'NoneType'>")
def test_opt_output(self):
from onnxscript.tests.models import opt_output

Expand All @@ -280,9 +280,7 @@ def test_opt_input(self):

self.validate_save(opt_input, shape_inference=False)

@unittest.skipIf(
True, reason="ValueError: A function with attributes " "cannot be exported as a model."
)
@unittest.skip("A function with attributes cannot be exported as a model.")
def test_onnxfns2(self):
from onnxscript.tests.models import onnxfns2

Expand Down
126 changes: 83 additions & 43 deletions onnxscript/irbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# --------------------------------------------------------------------------
from __future__ import annotations

import dataclasses
import io
import logging
import warnings
Expand Down Expand Up @@ -108,9 +109,15 @@ def _opt_var_to_str(x):


class IRAttributeValue:
"""An attribute value (representing an actual parameter)."""
"""An attribute value (representing an actual parameter).

def __init__(self, attrproto) -> None:
Attributes:
name: The name of the attribute.
type: The type of the attribute.
attr_proto: The attribute proto.
"""

def __init__(self, attrproto: onnx.AttributeProto) -> None:
self.attr_proto = attrproto

def __str__(self):
Expand All @@ -120,14 +127,51 @@ def __str__(self):
return helper.printable_attribute(self.attr_proto)

@property
def name(self):
def name(self) -> str:
return self.attr_proto.name

@property
def type(self):
def type(self) -> onnx.AttributeProto.AttributeType:
return self.attr_proto.type


@dataclasses.dataclass(frozen=True)
class IRAttributeParameter:
"""An attribute parameter (representing a formal parameter).

It may or may not carry a default value.

Attributes:
name: The name of the attribute.
type: The type of the attribute.
has_default: Whether the attribute has a default value.
attr_proto: The attribute proto.
"""

name: str
type: onnx.AttributeProto.AttributeType
default_value: str | int | float | None = None

def __str__(self):
if self.has_default:
return helper.printable_attribute(self.attr_proto)
# TODO(justinchuby): Include a readable type name.
return self.name

@property
def has_default(self):
return self.default_value is not None

@property
def attr_proto(self) -> onnx.AttributeProto:
if not self.has_default:
raise ValueError(
"Attribute has no default value. Only attributes with default "
"values can be converted to AttributeProto."
)
return helper.make_attribute(self.name, self.default_value)


class IRStmt:
def __init__(
self,
Expand Down Expand Up @@ -191,9 +235,7 @@ def __init__(self, name: str, domain: str = "") -> None:
self.outputs: list[IRVar] = []
self.stmts: list[IRStmt] = []
# attribute parameters
self.attrs: list[str] = []
# attribute parameters with default value
self.attr_protos: list[IRAttributeValue] = []
self.attrs: list[IRAttributeParameter] = []
self.called_functions: dict[str, onnx.FunctionProto] = {}
self.docstring: str = ""
# a dictionary of nested function-definitions
Expand All @@ -207,11 +249,10 @@ def assigned_names(self) -> Sequence[str]:

def __str__(self):
attrs = _format(self.attrs, "<", ", ", ">") if self.attrs else ""
attr_protos = _format(self.attr_protos, "<", ", ", ">") if self.attr_protos else ""
inputs = _format([x.typed_str() for x in self.inputs], "(", ", ", ")")
outputs = _format([x.typed_str() for x in self.outputs], "(", ", ", ")")
stmts = _format(self.stmts, "\n{\n ", "\n ", "\n}\n")
return f"{self.name} {attrs}{attr_protos}{inputs} => {outputs}{stmts}"
return f"{self.name} {attrs}{inputs} => {outputs}{stmts}"

def append_docstring(self, docstring):
self.docstring += docstring
Expand All @@ -225,11 +266,8 @@ def append_input(self, name: IRVar) -> None:
def append_output(self, name: IRVar) -> None:
self.outputs.append(name)

def add_attr_parameter(self, attr: str | IRAttributeValue) -> None:
if isinstance(attr, IRAttributeValue):
self.attr_protos.append(attr)
else:
self.attrs.append(attr)
def add_attr_parameter(self, attr: IRAttributeParameter) -> None:
self.attrs.append(attr)

def debug_print(self):
if logger.isEnabledFor(logging.DEBUG):
Expand Down Expand Up @@ -398,19 +436,19 @@ def to_function_proto(self) -> onnx.FunctionProto:
onnx.helper.make_opsetid(domain, version) for domain, version in opsets.items()
]

# attribute_proto is introduced in version onnx==1.13.0.
# attribute_proto is introduced in version onnx==1.14.0.
# If this attribute is available, onnxscript uses it to
# default values for attributes. The function has then two
# lists, one list for attributes without default values,
# another one for attributes with default values.
# If this *attribute_proto* is not available,
# all attributes with a default value are moved to the first
# all attributes are moved to the first
# list, default values are removed.
# TODO: remove this when onnx with attribute_proto is released.
if hasattr(onnx.FunctionProto, "attribute_proto"):
atts = self.attrs
attribute_names = [attr.name for attr in self.attrs if not attr.has_default]
else:
atts = self.attrs + [a.attr_proto.name for a in self.attr_protos]
attribute_names = [attr.name for attr in self.attrs]

f = helper.make_function(
self.domain,
Expand All @@ -419,11 +457,13 @@ def to_function_proto(self) -> onnx.FunctionProto:
outputs=[y.name for y in self.outputs],
nodes=nodes,
opset_imports=opset_imports, # TODO
attributes=atts,
attributes=attribute_names,
doc_string=self.docstring,
)
if hasattr(onnx.FunctionProto, "attribute_proto"):
f.attribute_proto.extend([a.attr_proto for a in self.attr_protos])
f.attribute_proto.extend(
[attr.attr_proto for attr in self.attrs if attr.has_default]
)
return f


Expand All @@ -437,10 +477,10 @@ def __init__(self):
def new_function(self, name: str, domain: str = "", register: bool = False):
if register and (domain, name) in self.functions:
raise RuntimeError(f"Function '{name}' already exists in domain '{domain}'.")
fct = IRFunction(name, domain)
function = IRFunction(name, domain)
if register:
self.functions[domain, name] = fct
return fct
self.functions[domain, name] = function
return function

def add_docstring(self, fn: IRFunction, docstring: str):
fn.append_docstring(docstring)
Expand All @@ -454,34 +494,34 @@ def add_stmt(
attrs: Sequence[IRAttributeValue],
sub_functions=None,
) -> None:
s = IRStmt(results, callee, args, attrs, sub_functions=sub_functions)
fn.append_stmt(s)
stmt = IRStmt(results, callee, args, attrs, sub_functions=sub_functions)
fn.append_stmt(stmt)

def add_input(
self, fn: IRFunction, varname: str, type: IRTypeLike, info: SourceInfo
) -> None:
v = IRVar(varname, type, info)
fn.append_input(v)
var = IRVar(varname, type, info)
fn.append_input(var)

def add_attr_parameter(self, fn: IRFunction, varname: str, default_value) -> None:
if default_value is not None:
a = IRAttributeValue(helper.make_attribute(varname, default_value))
fn.add_attr_parameter(a)
else:
fn.add_attr_parameter(varname)
def add_attr_parameter(
self,
fn: IRFunction,
varname: str,
attribute_type: onnx.AttributeProto.AttributeType,
default_value: int | float | str | None,
) -> None:
fn.add_attr_parameter(IRAttributeParameter(varname, attribute_type, default_value))

def add_output(self, fn: IRFunction, varname: str, type, info) -> None:
v = IRVar(varname, type, info)
fn.append_output(v)
def add_output(self, fn: IRFunction, varname: str, typeinfo, sourceinfo) -> None:
var = IRVar(varname, typeinfo, sourceinfo)
fn.append_output(var)

def make_attr(self, attrname: str, attrval: Any) -> IRAttributeValue:
return IRAttributeValue(helper.make_attribute(attrname, attrval))

def make_attr_ref(self, attrname: str, refname: str, pytype: type) -> IRAttributeValue:
a = onnx.AttributeProto()
a.name = attrname
a.ref_attr_name = refname
type_ = ta.pytype_to_attrtype(pytype)
assert type_ is not None
a.type = type_
return IRAttributeValue(a)
proto = onnx.AttributeProto()
proto.name = attrname
proto.ref_attr_name = refname
proto.type = ta.pytype_to_attrtype(pytype)
return IRAttributeValue(proto)
Loading