Skip to content

Commit f4da452

Browse files
committed
tracked list
1 parent 605e06e commit f4da452

File tree

2 files changed

+115
-7
lines changed

2 files changed

+115
-7
lines changed

onnxscript/ir/_core.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1757,6 +1757,9 @@ class Value(_protocols.ValueProtocol, _display.PrettyPrintable):
17571757

17581758
__slots__ = (
17591759
"_const_value",
1760+
"_graph_initializer_of",
1761+
"_graph_input_of",
1762+
"_graph_output_of",
17601763
"_index",
17611764
"_metadata",
17621765
"_metadata_props",
@@ -1808,6 +1811,13 @@ def __init__(
18081811
self._uses: dict[Usage, None] = {}
18091812
self.doc_string = doc_string
18101813

1814+
# The graph this value belongs to. It is set *only* when the value is added as
1815+
# a graph input, graph output, or initializer.
1816+
# The three properties can only be set by the Graph class (GraphIO).
1817+
self._graph_initializer_of: Graph | None = None
1818+
self._graph_input_of: Graph | None = None
1819+
self._graph_output_of: Graph | None = None
1820+
18111821
def __repr__(self) -> str:
18121822
value_name = self.name if self.name else "anonymous:" + str(id(self))
18131823
type_text = f", type={self.type!r}" if self.type is not None else ""
@@ -1986,15 +1996,17 @@ def metadata_props(self) -> dict[str, str]:
19861996
self._metadata_props = {}
19871997
return self._metadata_props
19881998

1999+
def is_initializer(self) -> bool:
2000+
"""Whether the value is an initializer."""
2001+
return self._graph_initializer_of is not None
2002+
2003+
def is_graph_input(self) -> bool:
2004+
"""Whether the value is an input of a graph."""
2005+
return self._graph_input_of is not None
2006+
19892007
def is_graph_output(self) -> bool:
19902008
"""Whether the value is an output of a graph."""
1991-
if (producer := self.producer()) is None:
1992-
return False
1993-
if (graph := producer.graph) is None:
1994-
return False
1995-
# Cannot use `in` because __eq__ may be defined by subclasses, even though
1996-
# it is not recommended
1997-
return any(output is self for output in graph.outputs)
2009+
return self._graph_output_of is not None
19982010

19992011

20002012
def Input(

onnxscript/ir/_tracked_lists.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
"""Tracked lists for graph and node IO."""
4+
5+
from __future__ import annotations
6+
7+
import collections
8+
from typing import TYPE_CHECKING, Iterable, Literal, SupportsIndex
9+
10+
if TYPE_CHECKING:
11+
from onnxscript.ir import _core
12+
13+
14+
class GraphIO(collections.UserList[_core.Value]):
15+
"""The inputs and outputs of a Graph."""
16+
17+
def __init__(self, graph: _core.Graph, typ: Literal["input", "output"], initlist=None):
18+
super().__init__(initlist)
19+
self._graph = graph
20+
assert typ in {"intput", "output"}
21+
self._typ = typ
22+
23+
def _set_graph(self, value: _core.Value) -> None:
24+
"""Set the graph for the value."""
25+
if value._graph_input_of is not None and value._graph_input_of is not self._graph:
26+
raise ValueError(
27+
f"Value '{value}' is already an input of a different graph: {value._graph_input_of!r}"
28+
)
29+
if value._graph_output_of is not None and value._graph_output_of is not self._graph:
30+
raise ValueError(
31+
f"Value '{value}' is already an output of a different graph: {value._graph_output_of!r}"
32+
)
33+
34+
if self._typ == "input":
35+
value._graph_input_of = self._graph
36+
else:
37+
value._graph_output_of = self._graph
38+
39+
def _unset_graph(self, value: _core.Value) -> None:
40+
"""Unset the graph for the value."""
41+
if self._typ == "input":
42+
value._graph_input_of = None
43+
else:
44+
value._graph_output_of = None
45+
46+
def append(self, item: _core.Value) -> None:
47+
"""Add a new input to the graph."""
48+
super().append(item)
49+
self._set_graph(item)
50+
51+
def extend(self, other) -> None:
52+
"""Extend the list of inputs or outputs."""
53+
super().extend(other)
54+
for item in other:
55+
self._set_graph(item)
56+
57+
def insert(self, i: int, item: _core.Value) -> None:
58+
"""Insert an input/output to the graph."""
59+
super().insert(i, item)
60+
self._set_graph(item)
61+
62+
def pop(self, i: int = -1) -> _core.Value:
63+
"""Remove an input/output from the graph."""
64+
value = super().pop(i)
65+
self._unset_graph(value)
66+
return value
67+
68+
def remove(self, item: _core.Value) -> None:
69+
"""Remove an input/output from the graph."""
70+
super().remove(item)
71+
self._unset_graph(item)
72+
73+
def clear(self) -> None:
74+
"""Clear the list."""
75+
for value in self.data:
76+
self._unset_graph(value)
77+
super().clear()
78+
79+
def __setitem__(self, i, item) -> None:
80+
"""Replace an input/output to the node."""
81+
if isinstance(item, Iterable) and isinstance(i, slice):
82+
# Modify a slice of the list
83+
for value in self.data[i]:
84+
self._unset_graph(value)
85+
for value in item:
86+
self._set_graph(value)
87+
super().__setitem__(i, item)
88+
return
89+
elif isinstance(item, _core.Value) and isinstance(i, SupportsIndex):
90+
# Replace a single item
91+
self._unset_graph(self.data[i])
92+
self._set_graph(item)
93+
super().__setitem__(i, item)
94+
return
95+
96+
raise TypeError(f"Invalid types for __setitem__: {type(i)} and {type(item)}")

0 commit comments

Comments
 (0)