Skip to content

Commit 59ca12c

Browse files
committed
[IR] Record owning graph for input/output and initializers
1 parent da95c66 commit 59ca12c

File tree

2 files changed

+91
-27
lines changed

2 files changed

+91
-27
lines changed

onnxscript/ir/_core.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
Generic,
3232
Iterable,
3333
Iterator,
34+
MutableSequence,
3435
NamedTuple,
3536
OrderedDict,
3637
Sequence,
@@ -51,6 +52,7 @@
5152
_name_authority,
5253
_protocols,
5354
_type_casting,
55+
_tracked_containers,
5456
)
5557

5658
if typing.TYPE_CHECKING:
@@ -2116,8 +2118,8 @@ def __init__(
21162118
self.name = name
21172119

21182120
# Private fields that are not to be accessed by any other classes
2119-
self._inputs = list(inputs)
2120-
self._outputs = list(outputs)
2121+
self._inputs = _tracked_containers.GraphInputs(self, inputs)
2122+
self._outputs = _tracked_containers.GraphInputs(self, outputs)
21212123
self._initializers = {}
21222124
for initializer in initializers:
21232125
if isinstance(initializer, str):
@@ -2143,11 +2145,11 @@ def __init__(
21432145
self.extend(nodes)
21442146

21452147
@property
2146-
def inputs(self) -> list[Value]:
2148+
def inputs(self) -> MutableSequence[Value]:
21472149
return self._inputs
21482150

21492151
@property
2150-
def outputs(self) -> list[Value]:
2152+
def outputs(self) -> MutableSequence[Value]:
21512153
return self._outputs
21522154

21532155
@property
Lines changed: 85 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,52 +1,47 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT License.
3-
"""Tracked lists for graph and node IO."""
3+
"""Tracked containers for graph."""
44

55
from __future__ import annotations
66

7+
__all__ = [
8+
"GraphInputs",
9+
"GraphOutputs",
10+
]
11+
712
import collections
8-
from typing import TYPE_CHECKING, Iterable, Literal, SupportsIndex
13+
from typing import TYPE_CHECKING, Iterable, SupportsIndex
14+
15+
import onnxscript
916

1017
if TYPE_CHECKING:
1118
from onnxscript.ir import _core
1219

1320

14-
class GraphIO(collections.UserList[_core.Value]):
21+
class _GraphIO(collections.UserList[_core.Value]):
1522
"""The inputs and outputs of a Graph."""
1623

17-
def __init__(self, graph: _core.Graph, typ: Literal["input", "output"], initlist=None):
24+
def __init__(self, graph: _core.Graph, initlist=None):
1825
super().__init__(initlist)
1926
self._graph = graph
20-
assert typ in {"intput", "output"}
21-
self._typ = typ
27+
28+
def _check_invariance(self) -> None:
29+
"""Check the invariance of the graph."""
30+
raise NotImplementedError
2231

2332
def _set_graph(self, value: _core.Value) -> None:
2433
"""Set the graph for the value."""
25-
if value._graph_input_of is not None and value._graph_input_of is not self._graph:
26-
raise ValueError(
27-
f"Value '{value}' is already an input of a different graph: {value._graph_input_of!r}"
28-
)
29-
if value._graph_output_of is not None and value._graph_output_of is not self._graph:
30-
raise ValueError(
31-
f"Value '{value}' is already an output of a different graph: {value._graph_output_of!r}"
32-
)
33-
34-
if self._typ == "input":
35-
value._graph_input_of = self._graph
36-
else:
37-
value._graph_output_of = self._graph
34+
raise NotImplementedError
3835

3936
def _unset_graph(self, value: _core.Value) -> None:
4037
"""Unset the graph for the value."""
41-
if self._typ == "input":
42-
value._graph_input_of = None
43-
else:
44-
value._graph_output_of = None
38+
raise NotImplementedError
4539

4640
def append(self, item: _core.Value) -> None:
4741
"""Add a new input to the graph."""
4842
super().append(item)
4943
self._set_graph(item)
44+
self._check_invariance()
5045

5146
def extend(self, other) -> None:
5247
"""Extend the list of inputs or outputs."""
@@ -58,17 +53,20 @@ def insert(self, i: int, item: _core.Value) -> None:
5853
"""Insert an input/output to the graph."""
5954
super().insert(i, item)
6055
self._set_graph(item)
56+
self._check_invariance()
6157

6258
def pop(self, i: int = -1) -> _core.Value:
6359
"""Remove an input/output from the graph."""
6460
value = super().pop(i)
6561
self._unset_graph(value)
62+
self._check_invariance()
6663
return value
6764

6865
def remove(self, item: _core.Value) -> None:
6966
"""Remove an input/output from the graph."""
7067
super().remove(item)
7168
self._unset_graph(item)
69+
self._check_invariance()
7270

7371
def clear(self) -> None:
7472
"""Clear the list."""
@@ -85,12 +83,76 @@ def __setitem__(self, i, item) -> None:
8583
for value in item:
8684
self._set_graph(value)
8785
super().__setitem__(i, item)
86+
self._check_invariance()
8887
return
8988
elif isinstance(item, _core.Value) and isinstance(i, SupportsIndex):
9089
# Replace a single item
9190
self._unset_graph(self.data[i])
9291
self._set_graph(item)
9392
super().__setitem__(i, item)
93+
self._check_invariance()
9494
return
9595

9696
raise TypeError(f"Invalid types for __setitem__: {type(i)} and {type(item)}")
97+
98+
99+
class GraphInputs(_GraphIO):
100+
"""The inputs of a Graph."""
101+
102+
def __init__(self, graph: _core.Graph, initlist=None):
103+
super().__init__(graph, initlist)
104+
105+
def _check_invariance(self) -> None:
106+
"""Check the invariance of the graph."""
107+
if not onnxscript.DEBUG:
108+
return
109+
for value in self.data:
110+
if value._graph_input_of is self._graph:
111+
continue
112+
raise ValueError(
113+
f"Invariance error: Value '{value}' is not an input of the graph: {self._graph!r}"
114+
)
115+
116+
def _set_graph(self, value: _core.Value) -> None:
117+
"""Set the graph for the value."""
118+
if value._graph_input_of is not None and value._graph_input_of is not self._graph:
119+
raise ValueError(
120+
f"Value '{value}' is already an input of a different graph: {value._graph_input_of!r}"
121+
)
122+
123+
value._graph_input_of = self._graph
124+
125+
def _unset_graph(self, value: _core.Value) -> None:
126+
"""Unset the graph for the value."""
127+
value._graph_input_of = None
128+
129+
130+
class GraphOutputs(_GraphIO):
131+
"""The outputs of a Graph."""
132+
133+
def __init__(self, graph: _core.Graph, initlist=None):
134+
super().__init__(graph, initlist)
135+
136+
def _check_invariance(self) -> None:
137+
"""Check the invariance of the graph."""
138+
if not onnxscript.DEBUG:
139+
return
140+
for value in self.data:
141+
if value._graph_output_of is self._graph:
142+
continue
143+
raise ValueError(
144+
f"Invariance error: Value '{value}' is not an output of the graph: {self._graph!r}"
145+
)
146+
147+
def _set_graph(self, value: _core.Value) -> None:
148+
"""Set the graph for the value."""
149+
if value._graph_output_of is not None and value._graph_output_of is not self._graph:
150+
raise ValueError(
151+
f"Value '{value}' is already an output of a different graph: {value._graph_output_of!r}"
152+
)
153+
154+
value._graph_output_of = self._graph
155+
156+
def _unset_graph(self, value: _core.Value) -> None:
157+
"""Unset the graph for the value."""
158+
value._graph_output_of = None

0 commit comments

Comments
 (0)