@@ -1145,22 +1145,24 @@ def __init__(
1145
1145
Args:
1146
1146
domain: The domain of the operator. For onnx operators, this is an empty string.
1147
1147
op_type: The name of the operator.
1148
- inputs: The input values. When an input is None, it is an empty input.
1148
+ inputs: The input values. When an input is `` None`` , it is an empty input.
1149
1149
attributes: The attributes. RefAttr can be used only when the node is defined in a Function.
1150
1150
overload: The overload name when the node is invoking a function.
1151
1151
num_outputs: The number of outputs of the node. If not specified, the number is 1.
1152
- outputs: The output values. If None, the outputs are created during initialization.
1153
- version: The version of the operator. If None, the version is unspecified and will follow that of the graph.
1154
- graph: The graph that the node belongs to. If None, the node is not added to any graph.
1155
- A `Node` must belong to zero or one graph.
1156
- name: The name of the node. If None, the node is anonymous.
1152
+ outputs: The output values. If ``None``, the outputs are created during initialization.
1153
+ version: The version of the operator. If ``None``, the version is unspecified and will follow that of the graph.
1154
+ graph: The graph that the node belongs to. If ``None``, the node is not added to any graph.
1155
+ A `Node` must belong to zero or one graph. If a :class:`Function`, the underlying graph
1156
+ of the function is assigned to the node.
1157
+ name: The name of the node. If ``None``, the node is anonymous. The name may be
1158
+ set by a :class:`Graph` if ``graph`` is specified.
1157
1159
doc_string: The documentation string.
1158
1160
metadata_props: The metadata properties.
1159
1161
1160
1162
Raises:
1161
- TypeError: If the attributes are not Attr or RefAttr.
1162
- ValueError: If `num_outputs`, when not None, is not the same as the length of the outputs.
1163
- ValueError: If an output value is None, when outputs is specified.
1163
+ TypeError: If the attributes are not :class:` Attr` or :class:` RefAttr` .
1164
+ ValueError: If `` num_outputs`` , when not `` None`` , is not the same as the length of the outputs.
1165
+ ValueError: If an output value is `` None`` , when outputs is specified.
1164
1166
ValueError: If an output value has a producer set already, when outputs is specified.
1165
1167
"""
1166
1168
self ._name = name
@@ -1187,18 +1189,18 @@ def __init__(
1187
1189
self ._version : int | None = version
1188
1190
self ._metadata : _metadata .MetadataStore | None = None
1189
1191
self ._metadata_props : dict [str , str ] | None = metadata_props
1190
- self ._graph : Graph | Function | None = graph
1192
+ # _graph is set by graph.append
1193
+ self ._graph : Graph | None = None
1194
+ # Add the node to the graph if graph is specified
1195
+ if graph is not None :
1196
+ graph .append (self )
1191
1197
self .doc_string = doc_string
1192
1198
1193
1199
# Add the node as a use of the inputs
1194
1200
for i , input_value in enumerate (self ._inputs ):
1195
1201
if input_value is not None :
1196
1202
input_value ._add_usage (self , i ) # pylint: disable=protected-access
1197
1203
1198
- # Add the node to the graph if graph is specified
1199
- if self ._graph is not None :
1200
- self ._graph .append (self )
1201
-
1202
1204
def _create_outputs (
1203
1205
self , num_outputs : int | None , outputs : Sequence [Value ] | None
1204
1206
) -> tuple [Value , ...]:
@@ -1432,11 +1434,11 @@ def metadata_props(self) -> dict[str, str]:
1432
1434
return self ._metadata_props
1433
1435
1434
1436
@property
1435
- def graph (self ) -> Graph | Function | None :
1437
+ def graph (self ) -> Graph | None :
1436
1438
return self ._graph
1437
1439
1438
1440
@graph .setter
1439
- def graph (self , value : Graph | Function | None ) -> None :
1441
+ def graph (self , value : Graph | None ) -> None :
1440
1442
self ._graph = value
1441
1443
1442
1444
def op_identifier (self ) -> _protocols .OperatorIdentifier :
@@ -2178,7 +2180,7 @@ def sort(self) -> None:
2178
2180
# Obtain all nodes from the graph and its subgraphs for sorting
2179
2181
nodes = list (onnxscript .ir .traversal .RecursiveGraphIterator (self ))
2180
2182
# Store the sorted nodes of each subgraph
2181
- sorted_nodes_by_graph : dict [Graph | Function , list [Node ]] = {
2183
+ sorted_nodes_by_graph : dict [Graph , list [Node ]] = {
2182
2184
graph : [] for graph in {node .graph for node in nodes if node .graph is not None }
2183
2185
}
2184
2186
# TODO: Explain why we need to store direct predecessors and children and why
0 commit comments