Skip to content

Commit da95c66

Browse files
committed
GraphIO
1 parent 9c3f580 commit da95c66

File tree

2 files changed

+89
-170
lines changed

2 files changed

+89
-170
lines changed

onnxscript/ir/_core.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1757,6 +1757,9 @@ class Value(_protocols.ValueProtocol, _display.PrettyPrintable):
17571757

17581758
__slots__ = (
17591759
"_const_value",
1760+
"_graph_initializer_of",
1761+
"_graph_input_of",
1762+
"_graph_output_of",
17601763
"_index",
17611764
"_metadata",
17621765
"_metadata_props",
@@ -1808,6 +1811,13 @@ def __init__(
18081811
self._uses: dict[Usage, None] = {}
18091812
self.doc_string = doc_string
18101813

1814+
# The graph this value belongs to. It is set *only* when the value is added as
1815+
# a graph input, graph output, or initializer.
1816+
# The three properties can only be set by the Graph class (GraphIO).
1817+
self._graph_initializer_of: Graph | None = None
1818+
self._graph_input_of: Graph | None = None
1819+
self._graph_output_of: Graph | None = None
1820+
18111821
def __repr__(self) -> str:
18121822
value_name = self.name if self.name else "anonymous:" + str(id(self))
18131823
type_text = f", type={self.type!r}" if self.type is not None else ""
@@ -1986,15 +1996,17 @@ def metadata_props(self) -> dict[str, str]:
19861996
self._metadata_props = {}
19871997
return self._metadata_props
19881998

1999+
def is_initializer(self) -> bool:
2000+
"""Whether the value is an initializer."""
2001+
return self._graph_initializer_of is not None
2002+
2003+
def is_graph_input(self) -> bool:
2004+
"""Whether the value is an input of a graph."""
2005+
return self._graph_input_of is not None
2006+
19892007
def is_graph_output(self) -> bool:
19902008
"""Whether the value is an output of a graph."""
1991-
if (producer := self.producer()) is None:
1992-
return False
1993-
if (graph := producer.graph) is None:
1994-
return False
1995-
# Cannot use `in` because __eq__ may be defined by subclasses, even though
1996-
# it is not recommended
1997-
return any(output is self for output in graph.outputs)
2009+
return self._graph_output_of is not None
19982010

19992011

20002012
def Input(

onnxscript/ir/_tracked_lists.py

Lines changed: 70 additions & 163 deletions
Original file line numberDiff line numberDiff line change
@@ -5,185 +5,92 @@
55
from __future__ import annotations
66

77
import collections
8-
import contextlib
9-
from typing import TYPE_CHECKING, NoReturn
8+
from typing import TYPE_CHECKING, Iterable, Literal, SupportsIndex
109

1110
if TYPE_CHECKING:
1211
from onnxscript.ir import _core
1312

1413

15-
@contextlib.contextmanager
16-
def _update_value_usages(node: _core.Node, inputs: list[_core.Value | None]):
17-
"""Temporarily unset usages of the inputs."""
18-
for i, item in enumerate(inputs):
19-
if item is not None:
20-
item._remove_usage(node, i) # pylint: disable=protected-access
21-
try:
22-
# Caller will modify the inputs
23-
yield
24-
finally:
25-
for i, item in enumerate(inputs):
26-
if item is not None:
27-
item._add_usage(node, i) # pylint: disable=protected-access
14+
class GraphIO(collections.UserList[_core.Value]):
15+
"""The inputs and outputs of a Graph."""
2816

29-
30-
class NodeInputs(collections.UserList[_core.Value | None]):
31-
def __init__(self, node: _core.Node, initlist=None):
17+
def __init__(self, graph: _core.Graph, typ: Literal["input", "output"], initlist=None):
3218
super().__init__(initlist)
33-
self._node = node
34-
35-
def prepend(self, item: _core.Value | None) -> None:
36-
"""Add a new input to the node."""
37-
self.insert(0, item)
38-
39-
def append(self, item: _core.Value | None) -> None:
40-
"""Add a new input to the node."""
41-
index = len(self.data)
42-
if item is not None:
43-
item._add_usage(self._node, index) # pylint: disable=protected-access
44-
self.data.append(item)
45-
46-
def extend(self, other) -> None:
47-
for item in other:
48-
self.append(item)
49-
50-
def insert(self, i: int, item: _core.Value | None) -> None:
51-
with _update_value_usages(self._node, self.data):
52-
self.data.insert(i, item)
19+
self._graph = graph
20+
assert typ in {"intput", "output"}
21+
self._typ = typ
22+
23+
def _set_graph(self, value: _core.Value) -> None:
24+
"""Set the graph for the value."""
25+
if value._graph_input_of is not None and value._graph_input_of is not self._graph:
26+
raise ValueError(
27+
f"Value '{value}' is already an input of a different graph: {value._graph_input_of!r}"
28+
)
29+
if value._graph_output_of is not None and value._graph_output_of is not self._graph:
30+
raise ValueError(
31+
f"Value '{value}' is already an output of a different graph: {value._graph_output_of!r}"
32+
)
5333

54-
def pop(self, i: int = -1) -> _core.Value | None:
55-
item = self.data[i]
56-
if i == -1:
57-
# Remove the last item. No usages need to be updated
58-
if item is not None:
59-
item._remove_usage(self._node, i)
60-
return self.data.pop()
61-
# Otherwise we need to update usages
62-
with _update_value_usages(self._node, self.data):
63-
result = self.data.pop(i)
34+
if self._typ == "input":
35+
value._graph_input_of = self._graph
36+
else:
37+
value._graph_output_of = self._graph
6438

65-
return result
66-
67-
def clear(self) -> None:
68-
"""Clear the list."""
69-
for i, item in enumerate(self.data):
70-
if item is not None:
71-
item._remove_usage(self._node, i)
72-
self.data.clear()
73-
74-
def __setitem__(self, i: int, item: _core.Value | None) -> None:
75-
"""Replace an input to the node."""
76-
if i < -len(self.data) or i >= len(self.data):
77-
raise ValueError(f"index out of range: {i}")
78-
if i < 0:
79-
i += len(self.data)
80-
assert i >= 0
81-
old_input = self.data[i]
82-
if old_input is not None:
83-
old_input._remove_usage(self._node, i) # pylint: disable=protected-access
84-
if item is not None:
85-
item._add_usage(self._node, i) # pylint: disable=protected-access
86-
self.data[i] = item
87-
88-
def unsupported(self, *_args, **_kwargs):
89-
raise NotImplementedError("Method is not supported")
90-
91-
__lt__ = unsupported
92-
__le__ = unsupported
93-
__gt__ = unsupported
94-
__ge__ = unsupported
95-
__add__ = unsupported
96-
__radd__ = unsupported
97-
__iadd__ = unsupported
98-
__mul__ = unsupported
99-
reverse = unsupported
100-
sort = unsupported
101-
102-
103-
class NodeOutputs(collections.UserList[_core.Value]):
104-
def __init__(self, node: _core.Node, initlist=None):
105-
super().__init__(initlist)
106-
self._node = node
39+
def _unset_graph(self, value: _core.Value) -> None:
40+
"""Unset the graph for the value."""
41+
if self._typ == "input":
42+
value._graph_input_of = None
43+
else:
44+
value._graph_output_of = None
10745

10846
def append(self, item: _core.Value) -> None:
109-
"""Add a new output to the node."""
110-
if item.producer() is not None and item.producer() is not self._node:
111-
raise NotImplementedError(
112-
f"Output already has a producer that is not this node ({item.producer()}). "
113-
"An output value can be owned by only one node throughout its lifetime. "
114-
"Instead, create a new value and assign it to the output. Replace all usages of the old value with the new one."
115-
)
116-
item._producer = self._node # pylint: disable=protected-access
117-
item._index = len(self.data) # pylint: disable=protected-access
118-
self.data.append(item)
47+
"""Add a new input to the graph."""
48+
super().append(item)
49+
self._set_graph(item)
11950

12051
def extend(self, other) -> None:
52+
"""Extend the list of inputs or outputs."""
53+
super().extend(other)
12154
for item in other:
122-
self.append(item)
55+
self._set_graph(item)
56+
57+
def insert(self, i: int, item: _core.Value) -> None:
58+
"""Insert an input/output to the graph."""
59+
super().insert(i, item)
60+
self._set_graph(item)
12361

124-
def pop(self, i: int = -1) -> _core.Value | None:
125-
item = self.data[i]
126-
if i == -1:
127-
# Remove the last item. No usages need to be updated
128-
if item.uses():
129-
raise ValueError(
130-
f"Cannot remove output {item} because it is still used by other nodes."
131-
)
132-
return self.data.pop()
62+
def pop(self, i: int = -1) -> _core.Value:
63+
"""Remove an input/output from the graph."""
64+
value = super().pop(i)
65+
self._unset_graph(value)
66+
return value
13367

134-
for j, output in enumerate(self.data):
135-
output._index = j # pylint: disable=protected-access
68+
def remove(self, item: _core.Value) -> None:
69+
"""Remove an input/output from the graph."""
70+
super().remove(item)
71+
self._unset_graph(item)
13672

13773
def clear(self) -> None:
13874
"""Clear the list."""
139-
for _ in range(len(self.data)):
140-
self.pop()
141-
142-
def __setitem__(self, i: int, item: _core.Value) -> None:
143-
if item is None:
144-
raise NotImplementedError(
145-
"An output cannot be None. To remove a trailing output, use pop(). "
146-
"To remove an output in the middle, set its name to an empty string instead."
147-
)
148-
self.pop(i)
149-
self.insert(i, item)
150-
151-
def insert(self, i: int, item: _core.Value) -> None:
152-
"""Replace an output to the node."""
153-
if i < -len(self.data) or i >= len(self.data):
154-
raise ValueError(f"index out of range: {i}")
155-
if i < 0:
156-
i += len(self.data)
157-
assert i >= 0
158-
if item.producer() is not None and item.producer() is not self._node:
159-
raise NotImplementedError(
160-
f"Output already has a producer that is not this node ({item.producer()}). "
161-
"An output value can be owned by only one node throughout its lifetime. "
162-
"Instead, create a new value and assign it to the output. Replace all usages of the old value with the new one."
163-
)
164-
165-
item._producer = self._node # pylint: disable=protected-access
166-
167-
# Update the index of the item being replaced
168-
self.data.insert(i, item)
169-
for j, output in enumerate(self.data):
170-
output._index = j # pylint: disable=protected-access
171-
172-
173-
def unsupported(self, *_args, **_kwargs):
174-
raise NotImplementedError("Method is not supported")
175-
176-
__lt__ = unsupported
177-
__le__ = unsupported
178-
__gt__ = unsupported
179-
__ge__ = unsupported
180-
__add__ = unsupported
181-
__radd__ = unsupported
182-
__iadd__ = unsupported
183-
__mul__ = unsupported
184-
# NOTE: We don't support insertion and removal of items in the middle of the list
185-
# because in ONNX outputs are positional and index dependent. To remove an output in
186-
# the middle, set its name to an empty empty string instead.
187-
reverse = unsupported
188-
sort = unsupported
189-
remove = unsupported
75+
for value in self.data:
76+
self._unset_graph(value)
77+
super().clear()
78+
79+
def __setitem__(self, i, item) -> None:
80+
"""Replace an input/output to the node."""
81+
if isinstance(item, Iterable) and isinstance(i, slice):
82+
# Modify a slice of the list
83+
for value in self.data[i]:
84+
self._unset_graph(value)
85+
for value in item:
86+
self._set_graph(value)
87+
super().__setitem__(i, item)
88+
return
89+
elif isinstance(item, _core.Value) and isinstance(i, SupportsIndex):
90+
# Replace a single item
91+
self._unset_graph(self.data[i])
92+
self._set_graph(item)
93+
super().__setitem__(i, item)
94+
return
95+
96+
raise TypeError(f"Invalid types for __setitem__: {type(i)} and {type(item)}")

0 commit comments

Comments
 (0)