Skip to content

Commit ef8e889

Browse files
authored
[IR] Reconcile graph in Node (#2183)
Always assign a `Graph` object to the node's graph. Fix #2181
1 parent 9d16b89 commit ef8e889

File tree

2 files changed

+21
-19
lines changed

2 files changed

+21
-19
lines changed

onnxscript/ir/_core.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1145,22 +1145,24 @@ def __init__(
11451145
Args:
11461146
domain: The domain of the operator. For onnx operators, this is an empty string.
11471147
op_type: The name of the operator.
1148-
inputs: The input values. When an input is None, it is an empty input.
1148+
inputs: The input values. When an input is ``None``, it is an empty input.
11491149
attributes: The attributes. RefAttr can be used only when the node is defined in a Function.
11501150
overload: The overload name when the node is invoking a function.
11511151
num_outputs: The number of outputs of the node. If not specified, the number is 1.
1152-
outputs: The output values. If None, the outputs are created during initialization.
1153-
version: The version of the operator. If None, the version is unspecified and will follow that of the graph.
1154-
graph: The graph that the node belongs to. If None, the node is not added to any graph.
1155-
A `Node` must belong to zero or one graph.
1156-
name: The name of the node. If None, the node is anonymous.
1152+
outputs: The output values. If ``None``, the outputs are created during initialization.
1153+
version: The version of the operator. If ``None``, the version is unspecified and will follow that of the graph.
1154+
graph: The graph that the node belongs to. If ``None``, the node is not added to any graph.
1155+
A `Node` must belong to zero or one graph. If a :class:`Function`, the underlying graph
1156+
of the function is assigned to the node.
1157+
name: The name of the node. If ``None``, the node is anonymous. The name may be
1158+
set by a :class:`Graph` if ``graph`` is specified.
11571159
doc_string: The documentation string.
11581160
metadata_props: The metadata properties.
11591161
11601162
Raises:
1161-
TypeError: If the attributes are not Attr or RefAttr.
1162-
ValueError: If `num_outputs`, when not None, is not the same as the length of the outputs.
1163-
ValueError: If an output value is None, when outputs is specified.
1163+
TypeError: If the attributes are not :class:`Attr` or :class:`RefAttr`.
1164+
ValueError: If ``num_outputs``, when not ``None``, is not the same as the length of the outputs.
1165+
ValueError: If an output value is ``None``, when outputs is specified.
11641166
ValueError: If an output value has a producer set already, when outputs is specified.
11651167
"""
11661168
self._name = name
@@ -1187,18 +1189,18 @@ def __init__(
11871189
self._version: int | None = version
11881190
self._metadata: _metadata.MetadataStore | None = None
11891191
self._metadata_props: dict[str, str] | None = metadata_props
1190-
self._graph: Graph | Function | None = graph
1192+
# _graph is set by graph.append
1193+
self._graph: Graph | None = None
1194+
# Add the node to the graph if graph is specified
1195+
if graph is not None:
1196+
graph.append(self)
11911197
self.doc_string = doc_string
11921198

11931199
# Add the node as a use of the inputs
11941200
for i, input_value in enumerate(self._inputs):
11951201
if input_value is not None:
11961202
input_value._add_usage(self, i) # pylint: disable=protected-access
11971203

1198-
# Add the node to the graph if graph is specified
1199-
if self._graph is not None:
1200-
self._graph.append(self)
1201-
12021204
def _create_outputs(
12031205
self, num_outputs: int | None, outputs: Sequence[Value] | None
12041206
) -> tuple[Value, ...]:
@@ -1432,11 +1434,11 @@ def metadata_props(self) -> dict[str, str]:
14321434
return self._metadata_props
14331435

14341436
@property
1435-
def graph(self) -> Graph | Function | None:
1437+
def graph(self) -> Graph | None:
14361438
return self._graph
14371439

14381440
@graph.setter
1439-
def graph(self, value: Graph | Function | None) -> None:
1441+
def graph(self, value: Graph | None) -> None:
14401442
self._graph = value
14411443

14421444
def op_identifier(self) -> _protocols.OperatorIdentifier:
@@ -2178,7 +2180,7 @@ def sort(self) -> None:
21782180
# Obtain all nodes from the graph and its subgraphs for sorting
21792181
nodes = list(onnxscript.ir.traversal.RecursiveGraphIterator(self))
21802182
# Store the sorted nodes of each subgraph
2181-
sorted_nodes_by_graph: dict[Graph | Function, list[Node]] = {
2183+
sorted_nodes_by_graph: dict[Graph, list[Node]] = {
21822184
graph: [] for graph in {node.graph for node in nodes if node.graph is not None}
21832185
}
21842186
# TODO: Explain why we need to store direct predecessors and children and why

onnxscript/ir/passes/common/constant_manipulation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,9 @@ def call(self, model: ir.Model) -> ir.passes.PassResult:
6767
type=ir.TensorType(tensor.dtype),
6868
const_value=tensor,
6969
)
70-
assert isinstance(node.graph, ir.Graph)
70+
assert node.graph is not None
7171
node.graph.register_initializer(initializer)
72-
# Replace the constant node with the initilizer
72+
# Replace the constant node with the initializer
7373
ir.convenience.replace_all_uses_with(node.outputs[0], initializer)
7474
node.graph.remove(node, safe=True)
7575
count += 1

0 commit comments

Comments
 (0)