Skip to content

Commit 5701cbf

Browse files
authored
Merge attrs and attr_protos in IRFunction | chore(irbuilder) (#625)
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): * #626 * #631 * __->__ #625 Merge the two list in `IRFunction` by changing its type to `IRAttributeParameter`. This way we retain type information for all attributes. It is useful for creating correct `OpSchema`s and `ParamSchema`s in the next PR. Also Include `typeinfo` in `add_attr_parameter`.
1 parent 9bde82d commit 5701cbf

File tree

5 files changed

+113
-77
lines changed

5 files changed

+113
-77
lines changed

onnxscript/converter.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1326,6 +1326,7 @@ def translate_function_def(self, fn: ast.FunctionDef) -> irbuilder.IRFunction:
13261326
self.ir_builder.add_attr_parameter(
13271327
self.current_fn,
13281328
x.arg,
1329+
ta.pytype_to_attrtype(typeinfo),
13291330
default_value,
13301331
)
13311332
self.bind(x.arg, values.AttrRef(x.arg, typeinfo, self.source_of(x)))

onnxscript/converter_test.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ def test_renaming(self):
265265

266266
self.validate_save(renaming, shape_inference=False)
267267

268-
@unittest.skipIf(True, reason="TypeError: val must be numeric not <class 'NoneType'>")
268+
@unittest.skip(reason="TypeError: val must be numeric not <class 'NoneType'>")
269269
def test_opt_output(self):
270270
from onnxscript.tests.models import opt_output
271271

@@ -276,9 +276,7 @@ def test_opt_input(self):
276276

277277
self.validate_save(opt_input, shape_inference=False)
278278

279-
@unittest.skipIf(
280-
True, reason="ValueError: A function with attributes " "cannot be exported as a model."
281-
)
279+
@unittest.skip("A function with attributes cannot be exported as a model.")
282280
def test_onnxfns2(self):
283281
from onnxscript.tests.models import onnxfns2
284282

onnxscript/irbuilder.py

Lines changed: 86 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# --------------------------------------------------------------------------
55
from __future__ import annotations
66

7+
import dataclasses
78
import io
89
import logging
910
import warnings
@@ -108,9 +109,15 @@ def _opt_var_to_str(x):
108109

109110

110111
class IRAttributeValue:
111-
"""An attribute value (representing an actual parameter)."""
112+
"""An attribute value (representing an actual parameter).
112113
113-
def __init__(self, attrproto) -> None:
114+
Attributes:
115+
name: The name of the attribute.
116+
type: The type of the attribute.
117+
attr_proto: The attribute proto.
118+
"""
119+
120+
def __init__(self, attrproto: onnx.AttributeProto) -> None:
114121
self.attr_proto = attrproto
115122

116123
def __str__(self):
@@ -120,14 +127,54 @@ def __str__(self):
120127
return helper.printable_attribute(self.attr_proto)
121128

122129
@property
123-
def name(self):
130+
def name(self) -> str:
124131
return self.attr_proto.name
125132

126133
@property
127-
def type(self):
134+
def type(self) -> onnx.AttributeProto.AttributeType:
128135
return self.attr_proto.type
129136

130137

138+
@dataclasses.dataclass(frozen=True)
139+
class IRAttributeParameter:
140+
"""An attribute parameter (representing a formal parameter).
141+
142+
It may or may not carry a default value.
143+
144+
Attributes:
145+
name: The name of the attribute.
146+
type: The type of the attribute.
147+
default_value: The default value of the attribute.
148+
has_default: Whether the attribute has a default value.
149+
attr_proto: The attribute proto.
150+
"""
151+
152+
name: str
153+
type: onnx.AttributeProto.AttributeType
154+
default_value: str | int | float | None = None
155+
156+
# TODO(justinchuby): Validate the default_value is the same type as specified in AttributeType.
157+
158+
def __str__(self):
159+
if self.has_default:
160+
return helper.printable_attribute(self.attr_proto)
161+
# TODO(justinchuby): Include a readable type name.
162+
return self.name
163+
164+
@property
165+
def has_default(self):
166+
return self.default_value is not None
167+
168+
@property
169+
def attr_proto(self) -> onnx.AttributeProto:
170+
if not self.has_default:
171+
raise ValueError(
172+
"Attribute has no default value. Only attributes with default "
173+
"values can be converted to AttributeProto."
174+
)
175+
return helper.make_attribute(self.name, self.default_value)
176+
177+
131178
class IRStmt:
132179
def __init__(
133180
self,
@@ -191,9 +238,7 @@ def __init__(self, name: str, domain: str = "") -> None:
191238
self.outputs: list[IRVar] = []
192239
self.stmts: list[IRStmt] = []
193240
# attribute parameters
194-
self.attrs: list[str] = []
195-
# attribute parameters with default value
196-
self.attr_protos: list[IRAttributeValue] = []
241+
self.attrs: list[IRAttributeParameter] = []
197242
self.called_functions: dict[str, onnx.FunctionProto] = {}
198243
self.docstring: str = ""
199244
# a dictionary of nested function-definitions
@@ -207,11 +252,10 @@ def assigned_names(self) -> Sequence[str]:
207252

208253
def __str__(self):
209254
attrs = _format(self.attrs, "<", ", ", ">") if self.attrs else ""
210-
attr_protos = _format(self.attr_protos, "<", ", ", ">") if self.attr_protos else ""
211255
inputs = _format([x.typed_str() for x in self.inputs], "(", ", ", ")")
212256
outputs = _format([x.typed_str() for x in self.outputs], "(", ", ", ")")
213257
stmts = _format(self.stmts, "\n{\n ", "\n ", "\n}\n")
214-
return f"{self.name} {attrs}{attr_protos}{inputs} => {outputs}{stmts}"
258+
return f"{self.name} {attrs}{inputs} => {outputs}{stmts}"
215259

216260
def append_docstring(self, docstring):
217261
self.docstring += docstring
@@ -225,11 +269,8 @@ def append_input(self, name: IRVar) -> None:
225269
def append_output(self, name: IRVar) -> None:
226270
self.outputs.append(name)
227271

228-
def add_attr_parameter(self, attr: str | IRAttributeValue) -> None:
229-
if isinstance(attr, IRAttributeValue):
230-
self.attr_protos.append(attr)
231-
else:
232-
self.attrs.append(attr)
272+
def add_attr_parameter(self, attr: IRAttributeParameter) -> None:
273+
self.attrs.append(attr)
233274

234275
def debug_print(self):
235276
if logger.isEnabledFor(logging.DEBUG):
@@ -398,19 +439,19 @@ def to_function_proto(self) -> onnx.FunctionProto:
398439
onnx.helper.make_opsetid(domain, version) for domain, version in opsets.items()
399440
]
400441

401-
# attribute_proto is introduced in version onnx==1.13.0.
442+
# attribute_proto is introduced in version onnx==1.14.0.
402443
# If this attribute is available, onnxscript uses it to
403444
# default values for attributes. The function has then two
404445
# lists, one list for attributes without default values,
405446
# another one for attributes with default values.
406447
# If this *attribute_proto* is not available,
407-
# all attributes with a default value are moved to the first
448+
# all attributes are moved to the first
408449
# list, default values are removed.
409450
# TODO: remove this when onnx with attribute_proto is released.
410451
if hasattr(onnx.FunctionProto, "attribute_proto"):
411-
atts = self.attrs
452+
attribute_names = [attr.name for attr in self.attrs if not attr.has_default]
412453
else:
413-
atts = self.attrs + [a.attr_proto.name for a in self.attr_protos]
454+
attribute_names = [attr.name for attr in self.attrs]
414455

415456
f = helper.make_function(
416457
self.domain,
@@ -419,11 +460,13 @@ def to_function_proto(self) -> onnx.FunctionProto:
419460
outputs=[y.name for y in self.outputs],
420461
nodes=nodes,
421462
opset_imports=opset_imports, # TODO
422-
attributes=atts,
463+
attributes=attribute_names,
423464
doc_string=self.docstring,
424465
)
425466
if hasattr(onnx.FunctionProto, "attribute_proto"):
426-
f.attribute_proto.extend([a.attr_proto for a in self.attr_protos])
467+
f.attribute_proto.extend(
468+
[attr.attr_proto for attr in self.attrs if attr.has_default]
469+
)
427470
return f
428471

429472

@@ -437,10 +480,10 @@ def __init__(self):
437480
def new_function(self, name: str, domain: str = "", register: bool = False):
438481
if register and (domain, name) in self.functions:
439482
raise RuntimeError(f"Function '{name}' already exists in domain '{domain}'.")
440-
fct = IRFunction(name, domain)
483+
function = IRFunction(name, domain)
441484
if register:
442-
self.functions[domain, name] = fct
443-
return fct
485+
self.functions[domain, name] = function
486+
return function
444487

445488
def add_docstring(self, fn: IRFunction, docstring: str):
446489
fn.append_docstring(docstring)
@@ -454,34 +497,34 @@ def add_stmt(
454497
attrs: Sequence[IRAttributeValue],
455498
sub_functions=None,
456499
) -> None:
457-
s = IRStmt(results, callee, args, attrs, sub_functions=sub_functions)
458-
fn.append_stmt(s)
500+
stmt = IRStmt(results, callee, args, attrs, sub_functions=sub_functions)
501+
fn.append_stmt(stmt)
459502

460503
def add_input(
461504
self, fn: IRFunction, varname: str, type: IRTypeLike, info: SourceInfo
462505
) -> None:
463-
v = IRVar(varname, type, info)
464-
fn.append_input(v)
506+
var = IRVar(varname, type, info)
507+
fn.append_input(var)
465508

466-
def add_attr_parameter(self, fn: IRFunction, varname: str, default_value) -> None:
467-
if default_value is not None:
468-
a = IRAttributeValue(helper.make_attribute(varname, default_value))
469-
fn.add_attr_parameter(a)
470-
else:
471-
fn.add_attr_parameter(varname)
509+
def add_attr_parameter(
510+
self,
511+
fn: IRFunction,
512+
varname: str,
513+
attribute_type: onnx.AttributeProto.AttributeType,
514+
default_value: int | float | str | None,
515+
) -> None:
516+
fn.add_attr_parameter(IRAttributeParameter(varname, attribute_type, default_value))
472517

473-
def add_output(self, fn: IRFunction, varname: str, type, info) -> None:
474-
v = IRVar(varname, type, info)
475-
fn.append_output(v)
518+
def add_output(self, fn: IRFunction, varname: str, typeinfo, sourceinfo) -> None:
519+
var = IRVar(varname, typeinfo, sourceinfo)
520+
fn.append_output(var)
476521

477522
def make_attr(self, attrname: str, attrval: Any) -> IRAttributeValue:
478523
return IRAttributeValue(helper.make_attribute(attrname, attrval))
479524

480525
def make_attr_ref(self, attrname: str, refname: str, pytype: type) -> IRAttributeValue:
481-
a = onnx.AttributeProto()
482-
a.name = attrname
483-
a.ref_attr_name = refname
484-
type_ = ta.pytype_to_attrtype(pytype)
485-
assert type_ is not None
486-
a.type = type_
487-
return IRAttributeValue(a)
526+
proto = onnx.AttributeProto()
527+
proto.name = attrname
528+
proto.ref_attr_name = refname
529+
proto.type = ta.pytype_to_attrtype(pytype)
530+
return IRAttributeValue(proto)

onnxscript/type_annotation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def _is_primitive_attr_type(typeinfo: TypeAnnotationValue) -> bool:
5959

6060
def pytype_to_attrtype(
6161
pytype: TypeAnnotationValue,
62-
) -> typing.Optional[onnx.AttributeProto.AttributeType]:
62+
) -> onnx.AttributeProto.AttributeType:
6363
pytype = _remove_annotation(pytype)
6464
if pytype in _PYTYPE_TO_ATTRTYPE_MAP:
6565
return _PYTYPE_TO_ATTRTYPE_MAP[pytype]
@@ -74,7 +74,7 @@ def pytype_to_attrtype(
7474
elt_type = get_args(pytype)[0]
7575
if elt_type in _LISTTYPE_TO_ATTRTYPE_MAP:
7676
return _LISTTYPE_TO_ATTRTYPE_MAP[elt_type]
77-
return None
77+
return onnx.AttributeProto.UNDEFINED
7878

7979

8080
def _is_tensor_type(typeinfo: TypeAnnotationValue) -> bool:

onnxscript/values.py

Lines changed: 22 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
# --------------------------------------------------------------------------
55
from __future__ import annotations
66

7-
import collections
87
import dataclasses
98
import logging
109
import types
@@ -311,40 +310,31 @@ def param_schemas(self) -> tuple[ParamSchema, ...]:
311310
# The first len(func_ir.inputs) arguments are onnx inputs
312311
inputs = function_ir.inputs
313312
# The rest is onnx attributes
314-
attributes = function_ir.attrs
315-
# Construct a dictionary of attributes with their names specified in the function
316-
# definition
317-
attr_name_to_protos = collections.OrderedDict(
318-
(attr.name, attr) for attr in function_ir.attr_protos
319-
)
320-
321-
# args with default value are attributes
313+
322314
schemas = []
323315
for arg in inputs:
324316
if isinstance(arg.typeinfo, onnx.TypeProto.Optional):
325317
required = False
326318
else:
327319
required = True
328-
param_schema = ParamSchema(
329-
name=arg.name, type=arg.typeinfo, is_input=True, required=required
320+
schemas.append(
321+
ParamSchema(name=arg.name, type=arg.typeinfo, is_input=True, required=required)
330322
)
331-
schemas.append(param_schema)
332-
333-
for attr_name in attributes:
334-
# Attributes without default values
335-
# FIXME(justinchuby): Where can we find the type?
336-
param_schema = ParamSchema(name=attr_name, type=None, is_input=False)
337-
schemas.append(param_schema)
338323

339-
for name, attr_value in attr_name_to_protos.items():
340-
param_schema = ParamSchema(
341-
name=name,
342-
type=_ATTRIBUTE_TYPE_TO_PYTHON_TYPE[attr_value.type],
343-
default=_get_attribute_value(attr_value.attr_proto),
344-
is_input=False,
345-
# All function attributes are required
324+
for attr_parameter in function_ir.attrs:
325+
schemas.append(
326+
ParamSchema(
327+
name=attr_parameter.name,
328+
type=_ATTRIBUTE_TYPE_TO_PYTHON_TYPE.get(
329+
onnx.defs.OpSchema.AttrType(attr_parameter.type) # type: ignore[call-arg]
330+
),
331+
default=_EmptyDefault
332+
if attr_parameter.default_value is None
333+
else attr_parameter.default_value,
334+
is_input=False,
335+
required=not attr_parameter.has_default,
336+
)
346337
)
347-
schemas.append(param_schema)
348338

349339
self._param_schemas = tuple(schemas)
350340
return self._param_schemas # type: ignore[return-value]
@@ -355,8 +345,12 @@ def to_function_proto(self):
355345

356346
def to_model_proto(self, **kwargs):
357347
"""Converts the function into :class:`onnx.ModelProto`."""
358-
if self.function_ir.attrs:
359-
raise ValueError("A function with attributes cannot be exported as a model.")
348+
if self.function_ir.attrs and any(
349+
not attr.has_default for attr in self.function_ir.attrs
350+
):
351+
raise ValueError(
352+
"A function with required attributes cannot be exported as a model."
353+
)
360354
# Note: The function must also have monomorphic type annotation for inputs/outputs
361355
# to be converted into a valid model. Otherwise, we can still produce an ONNX
362356
# model, but it will not pass the ONNX model checker. We do not report an error

0 commit comments

Comments
 (0)