2
2
# Licensed under the MIT License.
3
3
"""Convenience methods for constructing the IR."""
4
4
5
- # NOTE: This is a temporary solution for constructing the IR. It should be replaced
6
- # with a more permanent solution in the future.
7
-
8
5
from __future__ import annotations
9
6
10
- from typing import Any , Iterable , Iterator , List , Mapping , Optional , Sequence , Tuple
7
+ from typing import (
8
+ Any ,
9
+ Mapping ,
10
+ Optional ,
11
+ Sequence ,
12
+ Tuple ,
13
+ )
11
14
12
15
from onnxscript import ir
13
16
from onnxscript .ir import _convenience
14
17
18
+ # A type representing the domains/versions used in creating nodes in IR.
19
+ UsedOpsets = set [Tuple [str , Optional [int ]]]
20
+
21
+
22
+ class Tape :
23
+ """Tape class.
24
+
25
+ A tape is a recorder that collects nodes and initializers that are created so
26
+ that they can be used for creating a graph.
27
+
28
+ Example::
29
+ from onnxscript import ir
30
+
31
+ tape = ir.tape.Tape()
32
+ a = tape.initializer(ir.tensor([1, 2, 3], name="a"))
33
+ b: ir.Value = ...
34
+ c: ir.Value = ...
35
+ x = tape.op("Add", [a, b], attributes={"alpha": 1.0})
36
+ y = tape.op("Mul", [x, c], attributes={"beta": 2.0})
37
+ model = ir.Model(
38
+ graph := ir.Graph(
39
+ inputs=[b, c],
40
+ outputs=[y],
41
+ nodes=tape.nodes,
42
+ initializers=tape.initializers
43
+ opset_imports={"": 20},
44
+ ),
45
+ ir_version=10,
46
+ )
15
47
16
- class Tape (Iterable [ir .Node ]):
17
- """A tape for recording nodes that are created."""
48
+ Attributes:
49
+ graph_like: The graph to append the new nodes and initializers to. When
50
+ it is None, the nodes and initializers are creating without owned by a graph.
51
+ Initializers will not be added to functions because it is not supported by ONNX.
52
+ """
18
53
19
- def __init__ (self ) -> None :
54
+ def __init__ (self , graph_like : ir . Graph | ir . Function | None = None ) -> None :
20
55
self ._nodes : list [ir .Node ] = []
21
56
self ._initializers : list [ir .Value ] = []
57
+ self ._used_opsets : UsedOpsets = set ()
58
+ self .graph_like = graph_like
22
59
23
- def __iter__ (self ) -> Iterator [ ir . Node ] :
24
- return iter ( self ._nodes )
60
+ def __repr__ (self ) -> str :
61
+ return f"Tape(nodes= { self ._nodes } , initializers= { self . _initializers } )"
25
62
26
63
@property
27
64
def nodes (self ) -> Sequence [ir .Node ]:
@@ -31,19 +68,43 @@ def nodes(self) -> Sequence[ir.Node]:
31
68
def initializers (self ) -> Sequence [ir .Value ]:
32
69
return tuple (self ._initializers )
33
70
71
+ @property
72
+ def used_opsets (self ) -> UsedOpsets :
73
+ return self ._used_opsets
74
+
34
75
def op (
35
76
self ,
36
77
op_type : str ,
37
78
inputs : Sequence [ir .Value | None ],
38
79
attributes : Mapping [str , _convenience .SupportedAttrTypes ] | None = None ,
80
+ * ,
39
81
domain : str = "" ,
82
+ overload : str = "" ,
83
+ version : int | None = None ,
84
+ graph : ir .Graph | None = None ,
85
+ name : str | None = None ,
86
+ doc_string : str | None = None ,
87
+ metadata_props : dict [str , str ] | None = None ,
40
88
) -> ir .Value :
41
89
if attributes is None :
42
90
attrs : Sequence [ir .Attr | ir .RefAttr ] = ()
43
91
else :
44
92
attrs = _convenience .convert_attributes (attributes )
45
- node = ir .Node (domain , op_type , inputs , attributes = attrs , num_outputs = 1 )
93
+ node = ir .Node (
94
+ domain ,
95
+ op_type ,
96
+ inputs ,
97
+ attributes = attrs ,
98
+ num_outputs = 1 ,
99
+ overload = overload ,
100
+ version = version ,
101
+ graph = graph or self .graph_like ,
102
+ name = name ,
103
+ doc_string = doc_string ,
104
+ metadata_props = metadata_props ,
105
+ )
46
106
self ._nodes .append (node )
107
+ self ._used_opsets .add ((domain , version ))
47
108
48
109
return node .outputs [0 ]
49
110
@@ -55,13 +116,32 @@ def op_multi_output(
55
116
* ,
56
117
num_outputs : int ,
57
118
domain : str = "" ,
119
+ overload : str = "" ,
120
+ version : int | None = None ,
121
+ graph : ir .Graph | None = None ,
122
+ name : str | None = None ,
123
+ doc_string : str | None = None ,
124
+ metadata_props : dict [str , str ] | None = None ,
58
125
) -> Sequence [ir .Value ]:
59
126
if attributes is None :
60
127
attrs : Sequence [ir .Attr | ir .RefAttr ] = ()
61
128
else :
62
129
attrs = _convenience .convert_attributes (attributes )
63
- node = ir .Node (domain , op_type , inputs , attributes = attrs , num_outputs = num_outputs )
130
+ node = ir .Node (
131
+ domain ,
132
+ op_type ,
133
+ inputs ,
134
+ attributes = attrs ,
135
+ num_outputs = num_outputs ,
136
+ overload = overload ,
137
+ version = version ,
138
+ graph = graph or self .graph_like ,
139
+ name = name ,
140
+ doc_string = doc_string ,
141
+ metadata_props = metadata_props ,
142
+ )
64
143
self ._nodes .append (node )
144
+ self ._used_opsets .add ((domain , version ))
65
145
66
146
return node .outputs
67
147
@@ -74,20 +154,14 @@ def initializer(self, tensor: ir.TensorProtocol, name: str | None = None) -> ir.
74
154
name = name , shape = shape , type = ir .TensorType (tensor .dtype ), const_value = tensor
75
155
)
76
156
self ._initializers .append (value )
157
+ if isinstance (self .graph_like , ir .Graph ):
158
+ self .graph_like .register_initializer (value )
77
159
return value
78
160
79
161
80
- # A type representing the domains/versions used in creating nodes in IR.
81
- UsedOpsets = List [Tuple [str , Optional [int ]]]
82
-
83
-
84
162
class Builder (Tape ):
85
163
"""An extension of the tape that provides a more convenient API for constructing the IR."""
86
164
87
- def __init__ (self ):
88
- super ().__init__ ()
89
- self ._used_opsets : UsedOpsets = []
90
-
91
165
def __getattr__ (self , op_type : str ) -> Any :
92
166
return lambda * args , ** kwargs : self ._make_node (op_type , args , kwargs )
93
167
@@ -101,20 +175,22 @@ def _make_node(self, op_type: str, inputs: Sequence[ir.Value], kwargs: dict[str,
101
175
assert isinstance (outputs , int )
102
176
num_outputs = outputs
103
177
104
- self ._used_opsets .append ((domain , version ))
105
178
if num_outputs == 1 :
106
- value = super ().op (op_type , inputs = inputs , attributes = kwargs , domain = domain )
179
+ value = super ().op (
180
+ op_type , inputs = inputs , attributes = kwargs , domain = domain , version = version
181
+ )
107
182
if isinstance (outputs , Sequence ):
108
183
value .name = outputs [0 ]
109
184
return value
110
185
values = super ().op_multi_output (
111
- op_type , inputs = inputs , attributes = kwargs , domain = domain , num_outputs = num_outputs
186
+ op_type ,
187
+ inputs = inputs ,
188
+ attributes = kwargs ,
189
+ domain = domain ,
190
+ version = version ,
191
+ num_outputs = num_outputs ,
112
192
)
113
193
if isinstance (outputs , Sequence ):
114
194
for value , name in zip (values , outputs ):
115
195
value .name = name
116
196
return values
117
-
118
- @property
119
- def used_opsets (self ) -> UsedOpsets :
120
- return self ._used_opsets
0 commit comments