1
1
# Copyright (c) Microsoft Corporation.
2
2
# Licensed under the MIT License.
3
- """Tracked lists for graph and node IO ."""
3
+ """Tracked containers for graph."""
4
4
5
5
from __future__ import annotations
6
6
7
+ __all__ = [
8
+ "GraphInputs" ,
9
+ "GraphOutputs" ,
10
+ ]
11
+
7
12
import collections
8
- from typing import TYPE_CHECKING , Iterable , Literal , SupportsIndex
13
+ from typing import TYPE_CHECKING , Iterable , SupportsIndex
14
+
15
+ import onnxscript
9
16
10
17
if TYPE_CHECKING :
11
18
from onnxscript .ir import _core
12
19
13
20
14
- class GraphIO (collections .UserList [_core .Value ]):
21
+ class _GraphIO (collections .UserList [_core .Value ]):
15
22
"""The inputs and outputs of a Graph."""
16
23
17
- def __init__ (self , graph : _core .Graph , typ : Literal [ "input" , "output" ], initlist = None ):
24
+ def __init__ (self , graph : _core .Graph , initlist = None ):
18
25
super ().__init__ (initlist )
19
26
self ._graph = graph
20
- assert typ in {"intput" , "output" }
21
- self ._typ = typ
27
+
28
+ def _check_invariance (self ) -> None :
29
+ """Check the invariance of the graph."""
30
+ raise NotImplementedError
22
31
23
32
def _set_graph (self , value : _core .Value ) -> None :
24
33
"""Set the graph for the value."""
25
- if value ._graph_input_of is not None and value ._graph_input_of is not self ._graph :
26
- raise ValueError (
27
- f"Value '{ value } ' is already an input of a different graph: { value ._graph_input_of !r} "
28
- )
29
- if value ._graph_output_of is not None and value ._graph_output_of is not self ._graph :
30
- raise ValueError (
31
- f"Value '{ value } ' is already an output of a different graph: { value ._graph_output_of !r} "
32
- )
33
-
34
- if self ._typ == "input" :
35
- value ._graph_input_of = self ._graph
36
- else :
37
- value ._graph_output_of = self ._graph
34
+ raise NotImplementedError
38
35
39
36
def _unset_graph (self , value : _core .Value ) -> None :
40
37
"""Unset the graph for the value."""
41
- if self ._typ == "input" :
42
- value ._graph_input_of = None
43
- else :
44
- value ._graph_output_of = None
38
+ raise NotImplementedError
45
39
46
40
def append (self , item : _core .Value ) -> None :
47
41
"""Add a new input to the graph."""
48
42
super ().append (item )
49
43
self ._set_graph (item )
44
+ self ._check_invariance ()
50
45
51
46
def extend (self , other ) -> None :
52
47
"""Extend the list of inputs or outputs."""
@@ -58,17 +53,20 @@ def insert(self, i: int, item: _core.Value) -> None:
58
53
"""Insert an input/output to the graph."""
59
54
super ().insert (i , item )
60
55
self ._set_graph (item )
56
+ self ._check_invariance ()
61
57
62
58
def pop (self , i : int = - 1 ) -> _core .Value :
63
59
"""Remove an input/output from the graph."""
64
60
value = super ().pop (i )
65
61
self ._unset_graph (value )
62
+ self ._check_invariance ()
66
63
return value
67
64
68
65
def remove (self , item : _core .Value ) -> None :
69
66
"""Remove an input/output from the graph."""
70
67
super ().remove (item )
71
68
self ._unset_graph (item )
69
+ self ._check_invariance ()
72
70
73
71
def clear (self ) -> None :
74
72
"""Clear the list."""
@@ -85,12 +83,76 @@ def __setitem__(self, i, item) -> None:
85
83
for value in item :
86
84
self ._set_graph (value )
87
85
super ().__setitem__ (i , item )
86
+ self ._check_invariance ()
88
87
return
89
88
elif isinstance (item , _core .Value ) and isinstance (i , SupportsIndex ):
90
89
# Replace a single item
91
90
self ._unset_graph (self .data [i ])
92
91
self ._set_graph (item )
93
92
super ().__setitem__ (i , item )
93
+ self ._check_invariance ()
94
94
return
95
95
96
96
raise TypeError (f"Invalid types for __setitem__: { type (i )} and { type (item )} " )
97
+
98
+
99
+ class GraphInputs (_GraphIO ):
100
+ """The inputs of a Graph."""
101
+
102
+ def __init__ (self , graph : _core .Graph , initlist = None ):
103
+ super ().__init__ (graph , initlist )
104
+
105
+ def _check_invariance (self ) -> None :
106
+ """Check the invariance of the graph."""
107
+ if not onnxscript .DEBUG :
108
+ return
109
+ for value in self .data :
110
+ if value ._graph_input_of is self ._graph :
111
+ continue
112
+ raise ValueError (
113
+ f"Invariance error: Value '{ value } ' is not an input of the graph: { self ._graph !r} "
114
+ )
115
+
116
+ def _set_graph (self , value : _core .Value ) -> None :
117
+ """Set the graph for the value."""
118
+ if value ._graph_input_of is not None and value ._graph_input_of is not self ._graph :
119
+ raise ValueError (
120
+ f"Value '{ value } ' is already an input of a different graph: { value ._graph_input_of !r} "
121
+ )
122
+
123
+ value ._graph_input_of = self ._graph
124
+
125
+ def _unset_graph (self , value : _core .Value ) -> None :
126
+ """Unset the graph for the value."""
127
+ value ._graph_input_of = None
128
+
129
+
130
+ class GraphOutputs (_GraphIO ):
131
+ """The outputs of a Graph."""
132
+
133
+ def __init__ (self , graph : _core .Graph , initlist = None ):
134
+ super ().__init__ (graph , initlist )
135
+
136
+ def _check_invariance (self ) -> None :
137
+ """Check the invariance of the graph."""
138
+ if not onnxscript .DEBUG :
139
+ return
140
+ for value in self .data :
141
+ if value ._graph_output_of is self ._graph :
142
+ continue
143
+ raise ValueError (
144
+ f"Invariance error: Value '{ value } ' is not an output of the graph: { self ._graph !r} "
145
+ )
146
+
147
+ def _set_graph (self , value : _core .Value ) -> None :
148
+ """Set the graph for the value."""
149
+ if value ._graph_output_of is not None and value ._graph_output_of is not self ._graph :
150
+ raise ValueError (
151
+ f"Value '{ value } ' is already an output of a different graph: { value ._graph_output_of !r} "
152
+ )
153
+
154
+ value ._graph_output_of = self ._graph
155
+
156
+ def _unset_graph (self , value : _core .Value ) -> None :
157
+ """Unset the graph for the value."""
158
+ value ._graph_output_of = None
0 commit comments