Skip to content

Commit ad64b58

Browse files
authored
[IR] Expose the Tape module (#2127)
Expose the `Tape` class under `ir.tape` for simplifying graph construction in the IR. This is a secondary API for convenience. I updated `onnxscript/ir/passes/common/shape_inference_test.py` to demonstrate usage. I added an optional reference to the graph from `Tape`. When the graph is specified, the added nodes are appended to the graph. This provides users the ability to examine the graph as they build it up using Tape.
1 parent 971170d commit ad64b58

File tree

7 files changed

+239
-61
lines changed

7 files changed

+239
-61
lines changed

onnxscript/ir/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
"traversal",
99
"convenience",
1010
"external_data",
11+
"tape",
1112
# IR classes
1213
"Tensor",
1314
"ExternalTensor",
@@ -80,7 +81,7 @@
8081
"save",
8182
]
8283

83-
from onnxscript.ir import convenience, external_data, passes, serde, traversal
84+
from onnxscript.ir import convenience, external_data, passes, serde, tape, traversal
8485
from onnxscript.ir._convenience._constructors import node, tensor
8586
from onnxscript.ir._core import (
8687
Attr,

onnxscript/ir/_core.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1135,7 +1135,7 @@ def __init__(
11351135
num_outputs: int | None = None,
11361136
outputs: Sequence[Value] | None = None,
11371137
version: int | None = None,
1138-
graph: Graph | None = None,
1138+
graph: Graph | Function | None = None,
11391139
name: str | None = None,
11401140
doc_string: str | None = None,
11411141
metadata_props: dict[str, str] | None = None,
@@ -1187,7 +1187,7 @@ def __init__(
11871187
self._version: int | None = version
11881188
self._metadata: _metadata.MetadataStore | None = None
11891189
self._metadata_props: dict[str, str] | None = metadata_props
1190-
self._graph: Graph | None = graph
1190+
self._graph: Graph | Function | None = graph
11911191
self.doc_string = doc_string
11921192

11931193
# Add the node as a use of the inputs
@@ -1432,11 +1432,11 @@ def metadata_props(self) -> dict[str, str]:
14321432
return self._metadata_props
14331433

14341434
@property
1435-
def graph(self) -> Graph | None:
1435+
def graph(self) -> Graph | Function | None:
14361436
return self._graph
14371437

14381438
@graph.setter
1439-
def graph(self, value: Graph | None) -> None:
1439+
def graph(self, value: Graph | Function | None) -> None:
14401440
self._graph = value
14411441

14421442
def op_identifier(self) -> _protocols.OperatorIdentifier:
@@ -2162,15 +2162,15 @@ def sort(self) -> None:
21622162
21632163
This sort is stable. It preserves the original order as much as possible.
21642164
2165-
Referece: https://github.com/madelson/MedallionTopologicalSort#stable-sort
2165+
Reference: https://github.com/madelson/MedallionTopologicalSort#stable-sort
21662166
21672167
Raises:
21682168
ValueError: If the graph contains a cycle, making topological sorting impossible.
21692169
"""
21702170
# Obtain all nodes from the graph and its subgraphs for sorting
21712171
nodes = list(onnxscript.ir.traversal.RecursiveGraphIterator(self))
21722172
# Store the sorted nodes of each subgraph
2173-
sorted_nodes_by_graph: dict[Graph, list[Node]] = {
2173+
sorted_nodes_by_graph: dict[Graph | Function, list[Node]] = {
21742174
graph: [] for graph in {node.graph for node in nodes if node.graph is not None}
21752175
}
21762176
# TODO: Explain why we need to store direct predecessors and children and why
@@ -2193,7 +2193,7 @@ def add_predecessor(child: Node, predecessor: Node | None) -> None:
21932193
node_depth[predecessor] += 1
21942194

21952195
# 1. Build the direct predecessors of each node and the depth of each node
2196-
# for sorting topolocally using Kahn's algorithm.
2196+
# for sorting topologically using Kahn's algorithm.
21972197
# Note that when a node contains graph attributes (aka. has subgraphs),
21982198
# we consider all nodes in the subgraphs *predecessors* of this node. This
21992199
# way we ensure the implicit dependencies of the subgraphs are captured
@@ -2718,11 +2718,11 @@ def remove(self, nodes: Node | Iterable[Node], /, safe: bool = False) -> None:
27182718
"""
27192719
self._graph.remove(nodes, safe=safe)
27202720

2721-
def insert_after(self, node: Node, new_nodes: Iterable[Node], /) -> None:
2721+
def insert_after(self, node: Node, new_nodes: Iterable[Node] | Node, /) -> None:
27222722
"""Insert new nodes after the given node in O(#new_nodes) time."""
27232723
self._graph.insert_after(node, new_nodes)
27242724

2725-
def insert_before(self, node: Node, new_nodes: Iterable[Node], /) -> None:
2725+
def insert_before(self, node: Node, new_nodes: Iterable[Node] | Node, /) -> None:
27262726
"""Insert new nodes before the given node in O(#new_nodes) time."""
27272727
self._graph.insert_before(node, new_nodes)
27282728

onnxscript/ir/_protocols.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -320,11 +320,15 @@ def remove(self, node: NodeProtocol, /) -> None:
320320
"""Remove a node from the graph."""
321321
...
322322

323-
def insert_after(self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol], /) -> None:
323+
def insert_after(
324+
self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol] | NodeProtocol, /
325+
) -> None:
324326
"""Insert new nodes after the given node."""
325327
...
326328

327-
def insert_before(self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol], /) -> None:
329+
def insert_before(
330+
self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol] | NodeProtocol, /
331+
) -> None:
328332
"""Insert new nodes before the given node."""
329333
...
330334

@@ -589,11 +593,15 @@ def remove(self, node: NodeProtocol, /) -> None:
589593
"""Remove a node from the function."""
590594
...
591595

592-
def insert_after(self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol], /) -> None:
596+
def insert_after(
597+
self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol] | NodeProtocol, /
598+
) -> None:
593599
"""Insert new nodes after the given node."""
594600
...
595601

596-
def insert_before(self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol], /) -> None:
602+
def insert_before(
603+
self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol] | NodeProtocol, /
604+
) -> None:
597605
"""Insert new nodes before the given node."""
598606
...
599607

onnxscript/ir/_tape.py

Lines changed: 102 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,26 +2,63 @@
22
# Licensed under the MIT License.
33
"""Convenience methods for constructing the IR."""
44

5-
# NOTE: This is a temporary solution for constructing the IR. It should be replaced
6-
# with a more permanent solution in the future.
7-
85
from __future__ import annotations
96

10-
from typing import Any, Iterable, Iterator, List, Mapping, Optional, Sequence, Tuple
7+
from typing import (
8+
Any,
9+
Mapping,
10+
Optional,
11+
Sequence,
12+
Tuple,
13+
)
1114

1215
from onnxscript import ir
1316
from onnxscript.ir import _convenience
1417

18+
# A type representing the domains/versions used in creating nodes in IR.
19+
UsedOpsets = set[Tuple[str, Optional[int]]]
20+
21+
22+
class Tape:
23+
"""Tape class.
24+
25+
A tape is a recorder that collects nodes and initializers that are created so
26+
that they can be used for creating a graph.
27+
28+
Example::
29+
from onnxscript import ir
30+
31+
tape = ir.tape.Tape()
32+
a = tape.initializer(ir.tensor([1, 2, 3], name="a"))
33+
b: ir.Value = ...
34+
c: ir.Value = ...
35+
x = tape.op("Add", [a, b], attributes={"alpha": 1.0})
36+
y = tape.op("Mul", [x, c], attributes={"beta": 2.0})
37+
model = ir.Model(
38+
graph := ir.Graph(
39+
inputs=[b, c],
40+
outputs=[y],
41+
nodes=tape.nodes,
42+
initializers=tape.initializers
43+
opset_imports={"": 20},
44+
),
45+
ir_version=10,
46+
)
1547
16-
class Tape(Iterable[ir.Node]):
17-
"""A tape for recording nodes that are created."""
48+
Attributes:
49+
graph_like: The graph to append the new nodes and initializers to. When
50+
it is None, the nodes and initializers are creating without owned by a graph.
51+
Initializers will not be added to functions because it is not supported by ONNX.
52+
"""
1853

19-
def __init__(self) -> None:
54+
def __init__(self, graph_like: ir.Graph | ir.Function | None = None) -> None:
2055
self._nodes: list[ir.Node] = []
2156
self._initializers: list[ir.Value] = []
57+
self._used_opsets: UsedOpsets = set()
58+
self.graph_like = graph_like
2259

23-
def __iter__(self) -> Iterator[ir.Node]:
24-
return iter(self._nodes)
60+
def __repr__(self) -> str:
61+
return f"Tape(nodes={self._nodes}, initializers={self._initializers})"
2562

2663
@property
2764
def nodes(self) -> Sequence[ir.Node]:
@@ -31,19 +68,43 @@ def nodes(self) -> Sequence[ir.Node]:
3168
def initializers(self) -> Sequence[ir.Value]:
3269
return tuple(self._initializers)
3370

71+
@property
72+
def used_opsets(self) -> UsedOpsets:
73+
return self._used_opsets
74+
3475
def op(
3576
self,
3677
op_type: str,
3778
inputs: Sequence[ir.Value | None],
3879
attributes: Mapping[str, _convenience.SupportedAttrTypes] | None = None,
80+
*,
3981
domain: str = "",
82+
overload: str = "",
83+
version: int | None = None,
84+
graph: ir.Graph | None = None,
85+
name: str | None = None,
86+
doc_string: str | None = None,
87+
metadata_props: dict[str, str] | None = None,
4088
) -> ir.Value:
4189
if attributes is None:
4290
attrs: Sequence[ir.Attr | ir.RefAttr] = ()
4391
else:
4492
attrs = _convenience.convert_attributes(attributes)
45-
node = ir.Node(domain, op_type, inputs, attributes=attrs, num_outputs=1)
93+
node = ir.Node(
94+
domain,
95+
op_type,
96+
inputs,
97+
attributes=attrs,
98+
num_outputs=1,
99+
overload=overload,
100+
version=version,
101+
graph=graph or self.graph_like,
102+
name=name,
103+
doc_string=doc_string,
104+
metadata_props=metadata_props,
105+
)
46106
self._nodes.append(node)
107+
self._used_opsets.add((domain, version))
47108

48109
return node.outputs[0]
49110

@@ -55,13 +116,32 @@ def op_multi_output(
55116
*,
56117
num_outputs: int,
57118
domain: str = "",
119+
overload: str = "",
120+
version: int | None = None,
121+
graph: ir.Graph | None = None,
122+
name: str | None = None,
123+
doc_string: str | None = None,
124+
metadata_props: dict[str, str] | None = None,
58125
) -> Sequence[ir.Value]:
59126
if attributes is None:
60127
attrs: Sequence[ir.Attr | ir.RefAttr] = ()
61128
else:
62129
attrs = _convenience.convert_attributes(attributes)
63-
node = ir.Node(domain, op_type, inputs, attributes=attrs, num_outputs=num_outputs)
130+
node = ir.Node(
131+
domain,
132+
op_type,
133+
inputs,
134+
attributes=attrs,
135+
num_outputs=num_outputs,
136+
overload=overload,
137+
version=version,
138+
graph=graph or self.graph_like,
139+
name=name,
140+
doc_string=doc_string,
141+
metadata_props=metadata_props,
142+
)
64143
self._nodes.append(node)
144+
self._used_opsets.add((domain, version))
65145

66146
return node.outputs
67147

@@ -74,20 +154,14 @@ def initializer(self, tensor: ir.TensorProtocol, name: str | None = None) -> ir.
74154
name=name, shape=shape, type=ir.TensorType(tensor.dtype), const_value=tensor
75155
)
76156
self._initializers.append(value)
157+
if isinstance(self.graph_like, ir.Graph):
158+
self.graph_like.register_initializer(value)
77159
return value
78160

79161

80-
# A type representing the domains/versions used in creating nodes in IR.
81-
UsedOpsets = List[Tuple[str, Optional[int]]]
82-
83-
84162
class Builder(Tape):
85163
"""An extension of the tape that provides a more convenient API for constructing the IR."""
86164

87-
def __init__(self):
88-
super().__init__()
89-
self._used_opsets: UsedOpsets = []
90-
91165
def __getattr__(self, op_type: str) -> Any:
92166
return lambda *args, **kwargs: self._make_node(op_type, args, kwargs)
93167

@@ -101,20 +175,22 @@ def _make_node(self, op_type: str, inputs: Sequence[ir.Value], kwargs: dict[str,
101175
assert isinstance(outputs, int)
102176
num_outputs = outputs
103177

104-
self._used_opsets.append((domain, version))
105178
if num_outputs == 1:
106-
value = super().op(op_type, inputs=inputs, attributes=kwargs, domain=domain)
179+
value = super().op(
180+
op_type, inputs=inputs, attributes=kwargs, domain=domain, version=version
181+
)
107182
if isinstance(outputs, Sequence):
108183
value.name = outputs[0]
109184
return value
110185
values = super().op_multi_output(
111-
op_type, inputs=inputs, attributes=kwargs, domain=domain, num_outputs=num_outputs
186+
op_type,
187+
inputs=inputs,
188+
attributes=kwargs,
189+
domain=domain,
190+
version=version,
191+
num_outputs=num_outputs,
112192
)
113193
if isinstance(outputs, Sequence):
114194
for value, name in zip(values, outputs):
115195
value.name = name
116196
return values
117-
118-
@property
119-
def used_opsets(self) -> UsedOpsets:
120-
return self._used_opsets

0 commit comments

Comments
 (0)