Skip to content

Commit f46004e

Browse files
authored
Remove the RefAttr class (#2328)
## Rational We defined the class `RefAttr` in the IR to represent reference attributes in ONNX. Node attributes can be `Attr` and `RefAttr`. However, since most of the time we are working with concrete attributes, the union of types creates a typing situation where we always need to assert the types before taking the values, even if we know a `RefAttr` cannot exist (outside of a function definition). This additionally matches the definition of AttributeProto in ONNX. ## Change This change merged the two classes, and instead defines a `is_ref()` method for users to check the reference attribute. The change is BC breaking for usage like `isinstance(attr, ir.RefAttr)`. Fortunately all such usages exist in this code base and not in PyTorch, so we are safe to complete the change.
1 parent 7aba165 commit f46004e

File tree

13 files changed

+106
-98
lines changed

13 files changed

+106
-98
lines changed

onnxscript/ir/_convenience/__init__.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
_protocols.TensorProtocol, # This includes all in-memory tensor types
3333
onnx.TensorProto,
3434
_core.Attr,
35-
_core.RefAttr,
3635
_protocols.GraphProtocol,
3736
Sequence[_protocols.GraphProtocol],
3837
onnx.GraphProto,
@@ -50,7 +49,7 @@ def _infer_attribute_type(attr: SupportedAttrTypes) -> _enums.AttributeType:
5049
return _enums.AttributeType.FLOAT
5150
if isinstance(attr, str):
5251
return _enums.AttributeType.STRING
53-
if isinstance(attr, (_core.Attr, _core.RefAttr)):
52+
if isinstance(attr, _core.Attr):
5453
return attr.type
5554
if isinstance(attr, Sequence) and all(isinstance(x, int) for x in attr):
5655
return _enums.AttributeType.INTS
@@ -97,7 +96,7 @@ def convert_attribute(
9796
name: str,
9897
attr: SupportedAttrTypes,
9998
attr_type: _enums.AttributeType | None = None,
100-
) -> _core.Attr | _core.RefAttr:
99+
) -> _core.Attr:
101100
"""Convert a Python object to a _core.Attr object.
102101
103102
This method is useful when constructing nodes with attributes. It infers the
@@ -121,7 +120,7 @@ def convert_attribute(
121120
raise ValueError("attr_type must be provided when attr is None")
122121
return _core.Attr(name, attr_type, None)
123122

124-
if isinstance(attr, (_core.Attr, _core.RefAttr)):
123+
if isinstance(attr, _core.Attr):
125124
if attr.name != name:
126125
raise ValueError(
127126
f"Attribute name '{attr.name}' does not match provided name '{name}'"
@@ -181,7 +180,7 @@ def convert_attribute(
181180

182181
def convert_attributes(
183182
attrs: Mapping[str, SupportedAttrTypes],
184-
) -> list[_core.Attr | _core.RefAttr]:
183+
) -> list[_core.Attr]:
185184
"""Convert a dictionary of attributes to a list of _core.Attr objects.
186185
187186
It infers the attribute type based on the type of the value. The supported
@@ -247,7 +246,7 @@ def convert_attributes(
247246
Returns:
248247
A list of _core.Attr objects.
249248
"""
250-
attributes: list[_core.Attr | _core.RefAttr] = []
249+
attributes: list[_core.Attr] = []
251250
for name, attr in attrs.items():
252251
if attr is not None:
253252
attributes.append(convert_attribute(name, attr))

onnxscript/ir/_convenience/_constructors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ def node(
194194
A node with the given op_type and inputs.
195195
"""
196196
if attributes is None:
197-
attrs: Sequence[ir.Attr | ir.RefAttr] = ()
197+
attrs: Sequence[ir.Attr] = ()
198198
else:
199199
attrs = _convenience.convert_attributes(attributes)
200200
return _core.Node(

onnxscript/ir/_core.py

Lines changed: 54 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1321,7 +1321,7 @@ def __init__(
13211321
domain: str,
13221322
op_type: str,
13231323
inputs: Iterable[Value | None],
1324-
attributes: Iterable[Attr | RefAttr] = (),
1324+
attributes: Iterable[Attr] = (),
13251325
*,
13261326
overload: str = "",
13271327
num_outputs: int | None = None,
@@ -1353,7 +1353,7 @@ def __init__(
13531353
metadata_props: The metadata properties.
13541354
13551355
Raises:
1356-
TypeError: If the attributes are not :class:`Attr` or :class:`RefAttr`.
1356+
TypeError: If the attributes are not :class:`Attr`.
13571357
ValueError: If ``num_outputs``, when not ``None``, is not the same as the length of the outputs.
13581358
ValueError: If an output value is ``None``, when outputs is specified.
13591359
ValueError: If an output value has a producer set already, when outputs is specified.
@@ -1368,13 +1368,13 @@ def __init__(
13681368
# Values belong to their defining nodes. The values list is immutable
13691369
self._outputs: tuple[Value, ...] = self._create_outputs(num_outputs, outputs)
13701370
attributes = tuple(attributes)
1371-
if attributes and not isinstance(attributes[0], (Attr, RefAttr)):
1371+
if attributes and not isinstance(attributes[0], Attr):
13721372
raise TypeError(
1373-
f"Expected the attributes to be Attr or RefAttr, got {type(attributes[0])}. "
1373+
f"Expected the attributes to be Attr, got {type(attributes[0])}. "
13741374
"If you are copying the attributes from another node, make sure you call "
13751375
"node.attributes.values() because it is a dictionary."
13761376
)
1377-
self._attributes: OrderedDict[str, Attr | RefAttr] = OrderedDict(
1377+
self._attributes: OrderedDict[str, Attr] = OrderedDict(
13781378
(attr.name, attr) for attr in attributes
13791379
)
13801380
self._overload: str = overload
@@ -1633,7 +1633,7 @@ def outputs(self, _: Sequence[Value]) -> None:
16331633
raise AttributeError("outputs is immutable. Please create a new node instead.")
16341634

16351635
@property
1636-
def attributes(self) -> OrderedDict[str, Attr | RefAttr]:
1636+
def attributes(self) -> OrderedDict[str, Attr]:
16371637
"""The attributes of the node."""
16381638
return self._attributes
16391639

@@ -3106,22 +3106,28 @@ def __repr__(self) -> str:
31063106
return f"{self.__class__.__name__}({self.domain!r}, {self.name!r}, {self.overload!r}, inputs={self.inputs!r}, attributes={self.attributes!r}), outputs={self.outputs!r})"
31073107

31083108

3109-
class RefAttr(_protocols.ReferenceAttributeProtocol, _display.PrettyPrintable):
3110-
"""Reference attribute."""
3109+
class Attr(
3110+
_protocols.AttributeProtocol,
3111+
_protocols.ReferenceAttributeProtocol,
3112+
_display.PrettyPrintable,
3113+
):
3114+
"""Base class for ONNX attributes or references."""
31113115

3112-
__slots__ = ("_name", "_ref_attr_name", "_type", "doc_string")
3116+
__slots__ = ("_name", "_ref_attr_name", "_type", "_value", "doc_string")
31133117

31143118
def __init__(
31153119
self,
31163120
name: str,
3117-
ref_attr_name: str,
31183121
type: _enums.AttributeType,
3122+
value: Any,
3123+
ref_attr_name: str | None = None,
31193124
*,
31203125
doc_string: str | None = None,
3121-
) -> None:
3126+
):
31223127
self._name = name
3123-
self._ref_attr_name = ref_attr_name
31243128
self._type = type
3129+
self._value = value
3130+
self._ref_attr_name = ref_attr_name
31253131
self.doc_string = doc_string
31263132

31273133
@property
@@ -3132,43 +3138,21 @@ def name(self) -> str:
31323138
def name(self, value: str) -> None:
31333139
self._name = value
31343140

3135-
@property
3136-
def ref_attr_name(self) -> str:
3137-
return self._ref_attr_name
3138-
3139-
@ref_attr_name.setter
3140-
def ref_attr_name(self, value: str) -> None:
3141-
self._ref_attr_name = value
3142-
31433141
@property
31443142
def type(self) -> _enums.AttributeType:
31453143
return self._type
31463144

3147-
@type.setter
3148-
def type(self, value: _enums.AttributeType) -> None:
3149-
self._type = value
3150-
3151-
def __repr__(self) -> str:
3152-
return f"{self.__class__.__name__}({self._name!r}, {self._type!r}, ref_attr_name={self.ref_attr_name!r})"
3153-
3154-
3155-
class Attr(_protocols.AttributeProtocol, _display.PrettyPrintable):
3156-
"""Base class for ONNX attributes."""
3145+
@property
3146+
def value(self) -> Any:
3147+
return self._value
31573148

3158-
__slots__ = ("doc_string", "name", "type", "value")
3149+
@property
3150+
def ref_attr_name(self) -> str | None:
3151+
return self._ref_attr_name
31593152

3160-
def __init__(
3161-
self,
3162-
name: str,
3163-
type: _enums.AttributeType,
3164-
value: Any,
3165-
*,
3166-
doc_string: str | None = None,
3167-
):
3168-
self.name = name
3169-
self.type = type
3170-
self.value = value
3171-
self.doc_string = doc_string
3153+
def is_ref(self) -> bool:
3154+
"""Check if this attribute is a reference attribute."""
3155+
return self.ref_attr_name is not None
31723156

31733157
def __eq__(self, other: object) -> bool:
31743158
if not isinstance(other, _protocols.AttributeProtocol):
@@ -3185,11 +3169,15 @@ def __eq__(self, other: object) -> bool:
31853169
return True
31863170

31873171
def __str__(self) -> str:
3172+
if self.is_ref():
3173+
return f"@{self.ref_attr_name}"
31883174
if self.type == _enums.AttributeType.GRAPH:
31893175
return textwrap.indent("\n" + str(self.value), " " * 4)
31903176
return str(self.value)
31913177

31923178
def __repr__(self) -> str:
3179+
if self.is_ref():
3180+
return f"{self.__class__.__name__}({self.name!r}, {self.type!r}, ref_attr_name={self.ref_attr_name!r})"
31933181
return f"{self.__class__.__name__}({self.name!r}, {self.type!r}, {self.value!r})"
31943182

31953183
# Well typed getters
@@ -3269,6 +3257,29 @@ def as_graphs(self) -> Sequence[Graph]:
32693257

32703258

32713259
# NOTE: The following functions are just for convenience
3260+
3261+
3262+
def RefAttr(
3263+
name: str,
3264+
ref_attr_name: str,
3265+
type: _enums.AttributeType,
3266+
doc_string: str | None = None,
3267+
) -> Attr:
3268+
"""Create a reference attribute.
3269+
3270+
Args:
3271+
name: The name of the attribute.
3272+
type: The type of the attribute.
3273+
ref_attr_name: The name of the referenced attribute.
3274+
doc_string: Documentation string.
3275+
3276+
Returns:
3277+
A reference attribute.
3278+
"""
3279+
# NOTE: The function name is capitalized to maintain API backward compatibility.
3280+
return Attr(name, type, None, ref_attr_name=ref_attr_name, doc_string=doc_string)
3281+
3282+
32723283
def AttrFloat32(name: str, value: float, doc_string: str | None = None) -> Attr:
32733284
"""Create a float attribute."""
32743285
# NOTE: The function name is capitalized to maintain API backward compatibility.

onnxscript/ir/_protocols.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
Collection,
3737
Iterable,
3838
Iterator,
39+
Literal,
3940
Mapping,
4041
MutableMapping,
4142
MutableSequence,
@@ -422,6 +423,8 @@ class AttributeProtocol(Protocol):
422423
value: Any
423424
doc_string: str | None
424425

426+
def is_ref(self) -> Literal[False]: ...
427+
425428

426429
@typing.runtime_checkable
427430
class ReferenceAttributeProtocol(Protocol):
@@ -441,6 +444,8 @@ class ReferenceAttributeProtocol(Protocol):
441444
type: _enums.AttributeType
442445
doc_string: str | None
443446

447+
def is_ref(self) -> Literal[True]: ...
448+
444449

445450
@typing.runtime_checkable
446451
class SparseTensorProtocol(Protocol):

onnxscript/ir/_tape.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def op(
8989
output: ir.Value | None = None,
9090
) -> ir.Value:
9191
if attributes is None:
92-
attrs: Sequence[ir.Attr | ir.RefAttr] = ()
92+
attrs: Sequence[ir.Attr] = ()
9393
else:
9494
attrs = _convenience.convert_attributes(attributes)
9595
output_kwargs: dict[str, Any]
@@ -141,7 +141,7 @@ def op_multi_out(
141141
else:
142142
output_kwargs = dict(outputs=outputs)
143143
if attributes is None:
144-
attrs: Sequence[ir.Attr | ir.RefAttr] = ()
144+
attrs: Sequence[ir.Attr] = ()
145145
else:
146146
attrs = _convenience.convert_attributes(attributes)
147147
node = ir.Node(

onnxscript/ir/external_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def _all_tensors(
7070
# Look at constant attributes in nodes
7171
for node in _traversal.RecursiveGraphIterator(graph):
7272
for attr in node.attributes.values():
73-
if isinstance(attr, _core.RefAttr):
73+
if attr.is_ref():
7474
continue
7575
if attr.type == _enums.AttributeType.TENSOR and attr.value is not None:
7676
yield attr.value

onnxscript/ir/passes/common/inliner.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ class _CopyReplace:
5252
def __init__(
5353
self,
5454
inliner: InlinePass,
55-
attr_map: dict[str, ir.Attr | ir.RefAttr],
55+
attr_map: dict[str, ir.Attr],
5656
value_map: dict[ir.Value, ir.Value | None],
5757
metadata_props: dict[str, str],
5858
call_stack: CallStack,
@@ -83,8 +83,8 @@ def clone_optional_value(self, value: ir.Value | None) -> ir.Value | None:
8383
return None
8484
return self.clone_value(value)
8585

86-
def clone_attr(self, key: str, attr: ir.Attr | ir.RefAttr) -> ir.Attr | ir.RefAttr | None:
87-
if isinstance(attr, ir.Attr):
86+
def clone_attr(self, key: str, attr: ir.Attr) -> ir.Attr | None:
87+
if not attr.is_ref():
8888
if attr.type == ir.AttributeType.GRAPH:
8989
graph = self.clone_graph(attr.as_graph())
9090
return ir.Attr(key, ir.AttributeType.GRAPH, graph, doc_string=attr.doc_string)
@@ -94,15 +94,15 @@ def clone_attr(self, key: str, attr: ir.Attr | ir.RefAttr) -> ir.Attr | ir.RefAt
9494
key, ir.AttributeType.GRAPHS, graphs, doc_string=attr.doc_string
9595
)
9696
return attr
97-
assert isinstance(attr, ir.RefAttr)
97+
assert attr.is_ref()
9898
ref_attr_name = attr.ref_attr_name
9999
if ref_attr_name in self._attr_map:
100100
ref_attr = self._attr_map[ref_attr_name]
101-
if isinstance(ref_attr, ir.Attr):
101+
if not ref_attr.is_ref():
102102
return ir.Attr(
103103
key, ref_attr.type, ref_attr.value, doc_string=ref_attr.doc_string
104104
)
105-
assert isinstance(ref_attr, ir.RefAttr)
105+
assert ref_attr.ref_attr_name is not None
106106
return ir.RefAttr(
107107
key, ref_attr.ref_attr_name, ref_attr.type, doc_string=ref_attr.doc_string
108108
)
@@ -237,7 +237,7 @@ def _instantiate_call(self, node: ir.Node, call_site_id: CallSiteId) -> NodeRepl
237237
)
238238

239239
# Identify substitutions for both inputs and attributes of the function:
240-
attributes: dict[str, ir.Attr | ir.RefAttr] = node.attributes
240+
attributes: dict[str, ir.Attr] = node.attributes
241241
default_attr_values = {
242242
attr.name: attr
243243
for attr in function.attributes.values()

onnxscript/ir/serde.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -234,9 +234,10 @@ def to_proto(ir_object: object) -> object:
234234
return serialize_tensor(ir_object)
235235
if isinstance(ir_object, _protocols.ValueProtocol):
236236
return serialize_value(ir_object)
237-
if isinstance(ir_object, _protocols.AttributeProtocol):
237+
if isinstance(ir_object, _protocols.AttributeProtocol) and not ir_object.is_ref():
238238
return serialize_attribute(ir_object)
239239
if isinstance(ir_object, _protocols.ReferenceAttributeProtocol):
240+
assert ir_object.is_ref()
240241
return serialize_reference_attribute_into(onnx.AttributeProto(), ir_object)
241242
if isinstance(ir_object, _protocols.TypeProtocol):
242243
return serialize_type_into(onnx.TypeProto(), ir_object)
@@ -905,14 +906,14 @@ def deserialize_metadata_props(
905906
_deserialize_string_string_maps = deserialize_metadata_props
906907

907908

908-
def deserialize_attribute(proto: onnx.AttributeProto) -> _core.Attr | _core.RefAttr:
909+
def deserialize_attribute(proto: onnx.AttributeProto) -> _core.Attr:
909910
return _deserialize_attribute(proto, [])
910911

911912

912913
@_capture_errors(lambda proto, scoped_values: str(proto))
913914
def _deserialize_attribute(
914915
proto: onnx.AttributeProto, scoped_values: list[dict[str, _core.Value]]
915-
) -> _core.Attr | _core.RefAttr:
916+
) -> _core.Attr:
916917
name = proto.name
917918
doc_string = _get_field(proto, "doc_string")
918919
type_ = _enums.AttributeType(proto.type)
@@ -1465,20 +1466,10 @@ def serialize_node_into(node_proto: onnx.NodeProto, from_: _protocols.NodeProtoc
14651466
node_proto.output.append(output.name)
14661467

14671468
for attr in from_.attributes.values():
1468-
if isinstance(attr, _core.Attr):
1469+
if not attr.is_ref():
14691470
serialize_attribute_into(node_proto.attribute.add(), from_=attr)
1470-
elif isinstance(attr, _core.RefAttr):
1471-
serialize_reference_attribute_into(node_proto.attribute.add(), from_=attr)
1472-
# Handle protocol attributes for completeness. We do not check them first because
1473-
# calling isinstance on a protocol can be slow.
1474-
# Most of the time, we will have Attr or RefAttr so the two branches below
1475-
# will not be taken.
1476-
elif isinstance(attr, _protocols.AttributeProtocol):
1477-
serialize_attribute_into(node_proto.attribute.add(), from_=attr)
1478-
elif isinstance(attr, _protocols.ReferenceAttributeProtocol):
1479-
serialize_reference_attribute_into(node_proto.attribute.add(), from_=attr)
14801471
else:
1481-
raise TypeError(f"Unsupported attribute type: {type(attr)}")
1472+
serialize_reference_attribute_into(node_proto.attribute.add(), from_=attr)
14821473

14831474

14841475
def serialize_tensor(tensor: _protocols.TensorProtocol) -> onnx.TensorProto:

0 commit comments

Comments
 (0)