Skip to content

Commit e8eb966

Browse files
committed
Merge attrs and attr_protos in IRFunction | chore(irbuilder)
ghstack-source-id: 3582671 Pull Request resolved: #625
1 parent 52d2036 commit e8eb966

File tree

4 files changed

+64
-48
lines changed

4 files changed

+64
-48
lines changed

onnxscript/converter.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1326,6 +1326,7 @@ def translate_function_def(self, fn: ast.FunctionDef) -> irbuilder.IRFunction:
13261326
self.ir_builder.add_attr_parameter(
13271327
self.current_fn,
13281328
x.arg,
1329+
ta.pytype_to_attrtype(typeinfo),
13291330
default_value,
13301331
)
13311332
self.bind(x.arg, values.AttrRef(x.arg, typeinfo, self.source_of(x)))

onnxscript/irbuilder.py

Lines changed: 39 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -108,10 +108,16 @@ def _opt_var_to_str(x):
108108

109109

110110
class IRAttributeValue:
111-
"""An attribute value (representing an actual parameter)."""
111+
"""An attribute value (representing an actual parameter).
112112
113-
def __init__(self, attrproto) -> None:
113+
Attributes:
114+
attr_proto: The attribute proto
115+
has_default: Whether the attribute has a default value.
116+
"""
117+
118+
def __init__(self, attrproto, has_default: bool) -> None:
114119
self.attr_proto = attrproto
120+
self.has_default = has_default
115121

116122
def __str__(self):
117123
if self.attr_proto.HasField("ref_attr_name"):
@@ -191,9 +197,7 @@ def __init__(self, name: str, domain: str = "") -> None:
191197
self.outputs: list[IRVar] = []
192198
self.stmts: list[IRStmt] = []
193199
# attribute parameters
194-
self.attrs: list[str] = []
195-
# attribute parameters with default value
196-
self.attr_protos: list[IRAttributeValue] = []
200+
self.attrs: list[IRAttributeValue] = []
197201
self.called_functions: dict[str, onnx.FunctionProto] = {}
198202
self.docstring: str = ""
199203
# a dictionary of nested function-definitions
@@ -207,11 +211,10 @@ def assigned_names(self) -> Sequence[str]:
207211

208212
def __str__(self):
209213
attrs = _format(self.attrs, "<", ", ", ">") if self.attrs else ""
210-
attr_protos = _format(self.attr_protos, "<", ", ", ">") if self.attr_protos else ""
211214
inputs = _format([x.typed_str() for x in self.inputs], "(", ", ", ")")
212215
outputs = _format([x.typed_str() for x in self.outputs], "(", ", ", ")")
213216
stmts = _format(self.stmts, "\n{\n ", "\n ", "\n}\n")
214-
return f"{self.name} {attrs}{attr_protos}{inputs} => {outputs}{stmts}"
217+
return f"{self.name} {attrs}{inputs} => {outputs}{stmts}"
215218

216219
def append_docstring(self, docstring):
217220
self.docstring += docstring
@@ -225,11 +228,8 @@ def append_input(self, name: IRVar) -> None:
225228
def append_output(self, name: IRVar) -> None:
226229
self.outputs.append(name)
227230

228-
def add_attr_parameter(self, attr: str | IRAttributeValue) -> None:
229-
if isinstance(attr, IRAttributeValue):
230-
self.attr_protos.append(attr)
231-
else:
232-
self.attrs.append(attr)
231+
def add_attr_parameter(self, attr: IRAttributeValue) -> None:
232+
self.attrs.append(attr)
233233

234234
def debug_print(self):
235235
if logger.isEnabledFor(logging.DEBUG):
@@ -398,19 +398,19 @@ def to_function_proto(self) -> onnx.FunctionProto:
398398
onnx.helper.make_opsetid(domain, version) for domain, version in opsets.items()
399399
]
400400

401-
# attribute_proto is introduced in version onnx==1.13.0.
401+
# attribute_proto is introduced in version onnx==1.14.0.
402402
# If this attribute is available, onnxscript uses it to
403403
# default values for attributes. The function has then two
404404
# lists, one list for attributes without default values,
405405
# another one for attributes with default values.
406406
# If this *attribute_proto* is not available,
407-
# all attributes with a default value are moved to the first
407+
# all attributes are moved to the first
408408
# list, default values are removed.
409409
# TODO: remove this when onnx with attribute_proto is released.
410410
if hasattr(onnx.FunctionProto, "attribute_proto"):
411-
atts = self.attrs
411+
attribute_names = [attr.name for attr in self.attrs if not attr.has_default]
412412
else:
413-
atts = self.attrs + [a.attr_proto.name for a in self.attr_protos]
413+
attribute_names = [attr.name for attr in self.attrs]
414414

415415
f = helper.make_function(
416416
self.domain,
@@ -419,11 +419,13 @@ def to_function_proto(self) -> onnx.FunctionProto:
419419
outputs=[y.name for y in self.outputs],
420420
nodes=nodes,
421421
opset_imports=opset_imports, # TODO
422-
attributes=atts,
422+
attributes=attribute_names,
423423
doc_string=self.docstring,
424424
)
425425
if hasattr(onnx.FunctionProto, "attribute_proto"):
426-
f.attribute_proto.extend([a.attr_proto for a in self.attr_protos])
426+
f.attribute_proto.extend(
427+
[attr.attr_proto for attr in self.attrs if attr.has_default]
428+
)
427429
return f
428430

429431

@@ -463,25 +465,35 @@ def add_input(
463465
v = IRVar(varname, type, info)
464466
fn.append_input(v)
465467

466-
def add_attr_parameter(self, fn: IRFunction, varname: str, default_value) -> None:
468+
def add_attr_parameter(
469+
self,
470+
fn: IRFunction,
471+
varname: str,
472+
attribute_type: onnx.AttributeProto.AttributeType,
473+
default_value,
474+
) -> None:
467475
if default_value is not None:
468-
a = IRAttributeValue(helper.make_attribute(varname, default_value))
469-
fn.add_attr_parameter(a)
476+
fn.add_attr_parameter(
477+
IRAttributeValue(
478+
helper.make_attribute(varname, default_value), has_default=True
479+
)
480+
)
470481
else:
471-
fn.add_attr_parameter(varname)
482+
proto = onnx.AttributeProto()
483+
proto.name = varname
484+
proto.type = attribute_type
485+
fn.add_attr_parameter(IRAttributeValue(proto, has_default=False))
472486

473487
def add_output(self, fn: IRFunction, varname: str, type, info) -> None:
474488
v = IRVar(varname, type, info)
475489
fn.append_output(v)
476490

477491
def make_attr(self, attrname: str, attrval: Any) -> IRAttributeValue:
478-
return IRAttributeValue(helper.make_attribute(attrname, attrval))
492+
return IRAttributeValue(helper.make_attribute(attrname, attrval), has_default=True)
479493

480494
def make_attr_ref(self, attrname: str, refname: str, pytype: type) -> IRAttributeValue:
481495
a = onnx.AttributeProto()
482496
a.name = attrname
483497
a.ref_attr_name = refname
484-
type_ = ta.pytype_to_attrtype(pytype)
485-
assert type_ is not None
486-
a.type = type_
487-
return IRAttributeValue(a)
498+
a.type = ta.pytype_to_attrtype(pytype)
499+
return IRAttributeValue(a, has_default=False)

onnxscript/type_annotation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def _is_primitive_attr_type(typeinfo: TypeAnnotationValue) -> bool:
5959

6060
def pytype_to_attrtype(
6161
pytype: TypeAnnotationValue,
62-
) -> typing.Optional[onnx.AttributeProto.AttributeType]:
62+
) -> onnx.AttributeProto.AttributeType:
6363
pytype = _remove_annotation(pytype)
6464
if pytype in _PYTYPE_TO_ATTRTYPE_MAP:
6565
return _PYTYPE_TO_ATTRTYPE_MAP[pytype]
@@ -74,7 +74,7 @@ def pytype_to_attrtype(
7474
elt_type = get_args(pytype)[0]
7575
if elt_type in _LISTTYPE_TO_ATTRTYPE_MAP:
7676
return _LISTTYPE_TO_ATTRTYPE_MAP[elt_type]
77-
return None
77+
return onnx.AttributeProto.UNDEFINED
7878

7979

8080
def _is_tensor_type(typeinfo: TypeAnnotationValue) -> bool:

onnxscript/values.py

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -311,11 +311,10 @@ def param_schemas(self) -> tuple[ParamSchema, ...]:
311311
# The first len(func_ir.inputs) arguments are onnx inputs
312312
inputs = function_ir.inputs
313313
# The rest is onnx attributes
314-
attributes = function_ir.attrs
315314
# Construct a dictionary of attributes with their names specified in the function
316315
# definition
317316
attr_name_to_protos = collections.OrderedDict(
318-
(attr.name, attr) for attr in function_ir.attr_protos
317+
(attr.name, attr) for attr in function_ir.attrs
319318
)
320319

321320
# args with default value are attributes
@@ -325,26 +324,30 @@ def param_schemas(self) -> tuple[ParamSchema, ...]:
325324
required = False
326325
else:
327326
required = True
328-
param_schema = ParamSchema(
329-
name=arg.name, type=arg.typeinfo, is_input=True, required=required
327+
schemas.append(
328+
ParamSchema(name=arg.name, type=arg.typeinfo, is_input=True, required=required)
330329
)
331-
schemas.append(param_schema)
332-
333-
for attr_name in attributes:
334-
# Attributes without default values
335-
# FIXME(justinchuby): Where can we find the type?
336-
param_schema = ParamSchema(name=attr_name, type=None, is_input=False)
337-
schemas.append(param_schema)
338330

339331
for name, attr_value in attr_name_to_protos.items():
340-
param_schema = ParamSchema(
341-
name=name,
342-
type=_ATTRIBUTE_TYPE_TO_PYTHON_TYPE[attr_value.type],
343-
default=_get_attribute_value(attr_value.attr_proto),
344-
is_input=False,
345-
# All function attributes are required
346-
)
347-
schemas.append(param_schema)
332+
if not attr_value.has_default:
333+
schemas.append(
334+
ParamSchema(
335+
name=name,
336+
type=_ATTRIBUTE_TYPE_TO_PYTHON_TYPE[attr_value.type],
337+
is_input=False,
338+
required=True,
339+
)
340+
)
341+
else:
342+
schemas.append(
343+
ParamSchema(
344+
name=name,
345+
type=_ATTRIBUTE_TYPE_TO_PYTHON_TYPE[attr_value.type],
346+
default=_get_attribute_value(attr_value.attr_proto),
347+
is_input=False,
348+
required=True,
349+
)
350+
)
348351

349352
self._param_schemas = tuple(schemas)
350353
return self._param_schemas # type: ignore[return-value]

0 commit comments

Comments
 (0)