-
Notifications
You must be signed in to change notification settings - Fork 72
Auto generate OpSchema for functions #594
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
a3ec431
0224ad2
fc718d7
7240cd4
5febf7f
44cd785
9e6e8ba
49b7123
8c7e447
1e058ae
f859b8e
b591aea
1c525b0
15381e4
8bf732e
2c9c90e
544f1f7
4334edf
7fcd2e0
ed84d6f
d2a35bd
4984321
c1bdc51
fec676b
7207107
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -191,7 +191,7 @@ def __init__(self, name: str, domain: str = "") -> None: | |
self.outputs: list[IRVar] = [] | ||
self.stmts: list[IRStmt] = [] | ||
# attribute parameters | ||
self.attrs: list[str] = [] | ||
self.attrs: list[IRAttributeValue] = [] | ||
# attribute parameters with default value | ||
self.attr_protos: list[IRAttributeValue] = [] | ||
self.called_functions: dict[str, onnx.FunctionProto] = {} | ||
|
@@ -225,8 +225,8 @@ def append_input(self, name: IRVar) -> None: | |
def append_output(self, name: IRVar) -> None: | ||
self.outputs.append(name) | ||
|
||
def add_attr_parameter(self, attr: str | IRAttributeValue) -> None: | ||
if isinstance(attr, IRAttributeValue): | ||
def add_attr_parameter(self, attr: IRAttributeValue, has_default: bool) -> None: | ||
if has_default: | ||
self.attr_protos.append(attr) | ||
else: | ||
self.attrs.append(attr) | ||
|
@@ -398,7 +398,7 @@ def to_function_proto(self) -> onnx.FunctionProto: | |
onnx.helper.make_opsetid(domain, version) for domain, version in opsets.items() | ||
] | ||
|
||
# attribute_proto is introduced in version onnx==1.13.0. | ||
# attribute_proto is introduced in version onnx==1.14.0. | ||
# If this attribute is available, onnxscript uses it to | ||
# default values for attributes. The function has then two | ||
# lists, one list for attributes without default values, | ||
|
@@ -408,9 +408,12 @@ def to_function_proto(self) -> onnx.FunctionProto: | |
# list, default values are removed. | ||
# TODO: remove this when onnx with attribute_proto is released. | ||
if hasattr(onnx.FunctionProto, "attribute_proto"): | ||
atts = self.attrs | ||
attributes = [attr.name for attr in self.attrs] | ||
else: | ||
atts = self.attrs + [a.attr_proto.name for a in self.attr_protos] | ||
attributes = [ | ||
*[attr.name for attr in self.attrs], | ||
*[a.attr_proto.name for a in self.attr_protos], | ||
] | ||
|
||
f = helper.make_function( | ||
self.domain, | ||
|
@@ -419,11 +422,11 @@ def to_function_proto(self) -> onnx.FunctionProto: | |
outputs=[y.name for y in self.outputs], | ||
nodes=nodes, | ||
opset_imports=opset_imports, # TODO | ||
attributes=atts, | ||
attributes=attributes, | ||
doc_string=self.docstring, | ||
) | ||
if hasattr(onnx.FunctionProto, "attribute_proto"): | ||
f.attribute_proto.extend([a.attr_proto for a in self.attr_protos]) | ||
f.attribute_proto.extend([attr.attr_proto for attr in self.attr_protos]) | ||
return f | ||
|
||
|
||
|
@@ -463,12 +466,23 @@ def add_input( | |
v = IRVar(varname, type, info) | ||
fn.append_input(v) | ||
|
||
def add_attr_parameter(self, fn: IRFunction, varname: str, default_value) -> None: | ||
def add_attr_parameter( | ||
self, | ||
fn: IRFunction, | ||
varname: str, | ||
attribute_type: onnx.AttributeProto.AttributeType, | ||
default_value: Any, | ||
) -> None: | ||
if default_value is not None: | ||
a = IRAttributeValue(helper.make_attribute(varname, default_value)) | ||
fn.add_attr_parameter(a) | ||
fn.add_attr_parameter( | ||
IRAttributeValue(helper.make_attribute(varname, default_value)), | ||
has_default=True, | ||
) | ||
else: | ||
fn.add_attr_parameter(varname) | ||
proto = onnx.AttributeProto() | ||
proto.name = varname | ||
proto.type = attribute_type | ||
fn.add_attr_parameter(IRAttributeValue(proto), has_default=False) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you think has_defualt can be an attribute of IRAttributeValue? I think that's more straightforward. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's a great suggestion - I thought of this too. Will make an update |
||
|
||
def add_output(self, fn: IRFunction, varname: str, type, info) -> None: | ||
v = IRVar(varname, type, info) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -105,69 +105,105 @@ def to_type_proto(cls) -> onnx.TypeProto: | |
shape = [cls.shape] # example: "FLOAT[10]" | ||
return onnx.helper.make_tensor_type_proto(cls.dtype, shape) | ||
|
||
@classmethod | ||
def to_string(cls) -> str: | ||
raise NotImplementedError() | ||
|
||
|
||
class FLOAT(TensorType, dtype=onnx.TensorProto.FLOAT): | ||
pass | ||
@classmethod | ||
def to_string(cls): | ||
return "tensor(float)" | ||
|
||
|
||
class UINT8(TensorType, dtype=onnx.TensorProto.UINT8): | ||
pass | ||
@classmethod | ||
def to_string(cls): | ||
return "tensor(uint8)" | ||
|
||
|
||
class INT8(TensorType, dtype=onnx.TensorProto.INT8): | ||
pass | ||
@classmethod | ||
def to_string(cls): | ||
return "tensor(int8)" | ||
|
||
|
||
class UINT16(TensorType, dtype=onnx.TensorProto.UINT16): | ||
pass | ||
@classmethod | ||
def to_string(cls): | ||
return "tensor(uint16)" | ||
|
||
|
||
class INT16(TensorType, dtype=onnx.TensorProto.INT16): | ||
pass | ||
@classmethod | ||
def to_string(cls): | ||
return "tensor(int16)" | ||
|
||
|
||
class INT32(TensorType, dtype=onnx.TensorProto.INT32): | ||
pass | ||
@classmethod | ||
def to_string(cls): | ||
return "tensor(int32)" | ||
|
||
|
||
class INT64(TensorType, dtype=onnx.TensorProto.INT64): | ||
pass | ||
@classmethod | ||
def to_string(cls): | ||
return "tensor(int64)" | ||
|
||
|
||
class STRING(TensorType, dtype=onnx.TensorProto.STRING): | ||
pass | ||
@classmethod | ||
def to_string(cls): | ||
return "tensor(string)" | ||
|
||
|
||
class BOOL(TensorType, dtype=onnx.TensorProto.BOOL): | ||
pass | ||
@classmethod | ||
def to_string(cls): | ||
return "tensor(bool)" | ||
|
||
|
||
class FLOAT16(TensorType, dtype=onnx.TensorProto.FLOAT16): | ||
pass | ||
@classmethod | ||
def to_string(cls): | ||
return "tensor(float16)" | ||
|
||
|
||
class DOUBLE(TensorType, dtype=onnx.TensorProto.DOUBLE): | ||
pass | ||
@classmethod | ||
def to_string(cls): | ||
return "tensor(double)" | ||
|
||
|
||
class UINT32(TensorType, dtype=onnx.TensorProto.UINT32): | ||
pass | ||
@classmethod | ||
def to_string(cls): | ||
return "tensor(uint32)" | ||
|
||
|
||
class UINT64(TensorType, dtype=onnx.TensorProto.UINT64): | ||
pass | ||
@classmethod | ||
def to_string(cls): | ||
return "tensor(uint64)" | ||
|
||
|
||
class COMPLEX64(TensorType, dtype=onnx.TensorProto.COMPLEX64): | ||
pass | ||
@classmethod | ||
def to_string(cls): | ||
return "tensor(complex64)" | ||
|
||
|
||
class COMPLEX128(TensorType, dtype=onnx.TensorProto.COMPLEX128): | ||
pass | ||
@classmethod | ||
def to_string(cls): | ||
return "tensor(complex128)" | ||
|
||
|
||
class BFLOAT16(TensorType, dtype=onnx.TensorProto.BFLOAT16): | ||
pass | ||
@classmethod | ||
def to_string(cls): | ||
return "tensor(bfloat16)" | ||
|
||
|
||
def onnx_type_to_onnxscript_repr(onnx_type: onnx.TypeProto) -> str: | ||
|
@@ -203,3 +239,22 @@ def onnx_type_to_onnxscript_repr(onnx_type: onnx.TypeProto) -> str: | |
|
||
# Currently, only tensor types are supported. Need to expand support for other ONNX types. | ||
ONNXType = TensorType | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we still need this? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure. Maybe? |
||
|
||
ALL_TENSOR_TYPES = ( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you please sort them in alphabetical order for better readness? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
FLOAT, | ||
UINT8, | ||
INT8, | ||
UINT16, | ||
INT16, | ||
INT32, | ||
INT64, | ||
STRING, | ||
BOOL, | ||
FLOAT16, | ||
DOUBLE, | ||
UINT32, | ||
UINT64, | ||
COMPLEX64, | ||
COMPLEX128, | ||
BFLOAT16, | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a better annotation we can use than Any?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think so. Let me refine that