Skip to content

[IR] Expose the Tape module #2127

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Apr 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion onnxscript/ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"traversal",
"convenience",
"external_data",
"tape",
# IR classes
"Tensor",
"ExternalTensor",
Expand Down Expand Up @@ -80,7 +81,7 @@
"save",
]

from onnxscript.ir import convenience, external_data, passes, serde, traversal
from onnxscript.ir import convenience, external_data, passes, serde, tape, traversal
from onnxscript.ir._convenience._constructors import node, tensor
from onnxscript.ir._core import (
Attr,
Expand Down
18 changes: 9 additions & 9 deletions onnxscript/ir/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1135,7 +1135,7 @@ def __init__(
num_outputs: int | None = None,
outputs: Sequence[Value] | None = None,
version: int | None = None,
graph: Graph | None = None,
graph: Graph | Function | None = None,
name: str | None = None,
doc_string: str | None = None,
metadata_props: dict[str, str] | None = None,
Expand Down Expand Up @@ -1187,7 +1187,7 @@ def __init__(
self._version: int | None = version
self._metadata: _metadata.MetadataStore | None = None
self._metadata_props: dict[str, str] | None = metadata_props
self._graph: Graph | None = graph
self._graph: Graph | Function | None = graph
self.doc_string = doc_string

# Add the node as a use of the inputs
Expand Down Expand Up @@ -1432,11 +1432,11 @@ def metadata_props(self) -> dict[str, str]:
return self._metadata_props

@property
def graph(self) -> Graph | None:
def graph(self) -> Graph | Function | None:
return self._graph

@graph.setter
def graph(self, value: Graph | None) -> None:
def graph(self, value: Graph | Function | None) -> None:
self._graph = value

def op_identifier(self) -> _protocols.OperatorIdentifier:
Expand Down Expand Up @@ -2162,15 +2162,15 @@ def sort(self) -> None:

This sort is stable. It preserves the original order as much as possible.

Referece: https://github.com/madelson/MedallionTopologicalSort#stable-sort
Reference: https://github.com/madelson/MedallionTopologicalSort#stable-sort

Raises:
ValueError: If the graph contains a cycle, making topological sorting impossible.
"""
# Obtain all nodes from the graph and its subgraphs for sorting
nodes = list(onnxscript.ir.traversal.RecursiveGraphIterator(self))
# Store the sorted nodes of each subgraph
sorted_nodes_by_graph: dict[Graph, list[Node]] = {
sorted_nodes_by_graph: dict[Graph | Function, list[Node]] = {
graph: [] for graph in {node.graph for node in nodes if node.graph is not None}
}
# TODO: Explain why we need to store direct predecessors and children and why
Expand All @@ -2193,7 +2193,7 @@ def add_predecessor(child: Node, predecessor: Node | None) -> None:
node_depth[predecessor] += 1

# 1. Build the direct predecessors of each node and the depth of each node
# for sorting topolocally using Kahn's algorithm.
# for sorting topologically using Kahn's algorithm.
# Note that when a node contains graph attributes (aka. has subgraphs),
# we consider all nodes in the subgraphs *predecessors* of this node. This
# way we ensure the implicit dependencies of the subgraphs are captured
Expand Down Expand Up @@ -2718,11 +2718,11 @@ def remove(self, nodes: Node | Iterable[Node], /, safe: bool = False) -> None:
"""
self._graph.remove(nodes, safe=safe)

def insert_after(self, node: Node, new_nodes: Iterable[Node], /) -> None:
def insert_after(self, node: Node, new_nodes: Iterable[Node] | Node, /) -> None:
"""Insert new nodes after the given node in O(#new_nodes) time."""
self._graph.insert_after(node, new_nodes)

def insert_before(self, node: Node, new_nodes: Iterable[Node], /) -> None:
def insert_before(self, node: Node, new_nodes: Iterable[Node] | Node, /) -> None:
"""Insert new nodes before the given node in O(#new_nodes) time."""
self._graph.insert_before(node, new_nodes)

Expand Down
16 changes: 12 additions & 4 deletions onnxscript/ir/_protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,11 +320,15 @@ def remove(self, node: NodeProtocol, /) -> None:
"""Remove a node from the graph."""
...

def insert_after(self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol], /) -> None:
def insert_after(
self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol] | NodeProtocol, /
) -> None:
"""Insert new nodes after the given node."""
...

def insert_before(self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol], /) -> None:
def insert_before(
self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol] | NodeProtocol, /
) -> None:
"""Insert new nodes before the given node."""
...

Expand Down Expand Up @@ -589,11 +593,15 @@ def remove(self, node: NodeProtocol, /) -> None:
"""Remove a node from the function."""
...

def insert_after(self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol], /) -> None:
def insert_after(
self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol] | NodeProtocol, /
) -> None:
"""Insert new nodes after the given node."""
...

def insert_before(self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol], /) -> None:
def insert_before(
self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol] | NodeProtocol, /
) -> None:
"""Insert new nodes before the given node."""
...

Expand Down
128 changes: 102 additions & 26 deletions onnxscript/ir/_tape.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,63 @@
# Licensed under the MIT License.
"""Convenience methods for constructing the IR."""

# NOTE: This is a temporary solution for constructing the IR. It should be replaced
# with a more permanent solution in the future.

from __future__ import annotations

from typing import Any, Iterable, Iterator, List, Mapping, Optional, Sequence, Tuple
from typing import (
Any,
Mapping,
Optional,
Sequence,
Tuple,
)

from onnxscript import ir
from onnxscript.ir import _convenience

# A type representing the domains/versions used in creating nodes in IR.
UsedOpsets = set[Tuple[str, Optional[int]]]


class Tape:
"""Tape class.

A tape is a recorder that collects nodes and initializers that are created so
that they can be used for creating a graph.

Example::
from onnxscript import ir

tape = ir.tape.Tape()
a = tape.initializer(ir.tensor([1, 2, 3], name="a"))
b: ir.Value = ...
c: ir.Value = ...
x = tape.op("Add", [a, b], attributes={"alpha": 1.0})
y = tape.op("Mul", [x, c], attributes={"beta": 2.0})
model = ir.Model(
graph := ir.Graph(
inputs=[b, c],
outputs=[y],
nodes=tape.nodes,
initializers=tape.initializers
opset_imports={"": 20},
),
ir_version=10,
)

class Tape(Iterable[ir.Node]):
"""A tape for recording nodes that are created."""
Attributes:
graph_like: The graph to append the new nodes and initializers to. When
it is None, the nodes and initializers are creating without owned by a graph.
Initializers will not be added to functions because it is not supported by ONNX.
"""

def __init__(self) -> None:
def __init__(self, graph_like: ir.Graph | ir.Function | None = None) -> None:
self._nodes: list[ir.Node] = []
self._initializers: list[ir.Value] = []
self._used_opsets: UsedOpsets = set()
self.graph_like = graph_like

def __iter__(self) -> Iterator[ir.Node]:
return iter(self._nodes)
def __repr__(self) -> str:
return f"Tape(nodes={self._nodes}, initializers={self._initializers})"

@property
def nodes(self) -> Sequence[ir.Node]:
Expand All @@ -31,19 +68,43 @@ def nodes(self) -> Sequence[ir.Node]:
def initializers(self) -> Sequence[ir.Value]:
return tuple(self._initializers)

@property
def used_opsets(self) -> UsedOpsets:
return self._used_opsets

def op(
self,
op_type: str,
inputs: Sequence[ir.Value | None],
attributes: Mapping[str, _convenience.SupportedAttrTypes] | None = None,
*,
domain: str = "",
overload: str = "",
version: int | None = None,
graph: ir.Graph | None = None,
name: str | None = None,
doc_string: str | None = None,
metadata_props: dict[str, str] | None = None,
) -> ir.Value:
if attributes is None:
attrs: Sequence[ir.Attr | ir.RefAttr] = ()
else:
attrs = _convenience.convert_attributes(attributes)
node = ir.Node(domain, op_type, inputs, attributes=attrs, num_outputs=1)
node = ir.Node(
domain,
op_type,
inputs,
attributes=attrs,
num_outputs=1,
overload=overload,
version=version,
graph=graph or self.graph_like,
name=name,
doc_string=doc_string,
metadata_props=metadata_props,
)
self._nodes.append(node)
self._used_opsets.add((domain, version))

return node.outputs[0]

Expand All @@ -55,13 +116,32 @@ def op_multi_output(
*,
num_outputs: int,
domain: str = "",
overload: str = "",
version: int | None = None,
graph: ir.Graph | None = None,
name: str | None = None,
doc_string: str | None = None,
metadata_props: dict[str, str] | None = None,
) -> Sequence[ir.Value]:
if attributes is None:
attrs: Sequence[ir.Attr | ir.RefAttr] = ()
else:
attrs = _convenience.convert_attributes(attributes)
node = ir.Node(domain, op_type, inputs, attributes=attrs, num_outputs=num_outputs)
node = ir.Node(
domain,
op_type,
inputs,
attributes=attrs,
num_outputs=num_outputs,
overload=overload,
version=version,
graph=graph or self.graph_like,
name=name,
doc_string=doc_string,
metadata_props=metadata_props,
)
self._nodes.append(node)
self._used_opsets.add((domain, version))

return node.outputs

Expand All @@ -74,20 +154,14 @@ def initializer(self, tensor: ir.TensorProtocol, name: str | None = None) -> ir.
name=name, shape=shape, type=ir.TensorType(tensor.dtype), const_value=tensor
)
self._initializers.append(value)
if isinstance(self.graph_like, ir.Graph):
self.graph_like.register_initializer(value)
return value


# A type representing the domains/versions used in creating nodes in IR.
UsedOpsets = List[Tuple[str, Optional[int]]]


class Builder(Tape):
"""An extension of the tape that provides a more convenient API for constructing the IR."""

def __init__(self):
super().__init__()
self._used_opsets: UsedOpsets = []

def __getattr__(self, op_type: str) -> Any:
return lambda *args, **kwargs: self._make_node(op_type, args, kwargs)

Expand All @@ -101,20 +175,22 @@ def _make_node(self, op_type: str, inputs: Sequence[ir.Value], kwargs: dict[str,
assert isinstance(outputs, int)
num_outputs = outputs

self._used_opsets.append((domain, version))
if num_outputs == 1:
value = super().op(op_type, inputs=inputs, attributes=kwargs, domain=domain)
value = super().op(
op_type, inputs=inputs, attributes=kwargs, domain=domain, version=version
)
if isinstance(outputs, Sequence):
value.name = outputs[0]
return value
values = super().op_multi_output(
op_type, inputs=inputs, attributes=kwargs, domain=domain, num_outputs=num_outputs
op_type,
inputs=inputs,
attributes=kwargs,
domain=domain,
version=version,
num_outputs=num_outputs,
)
if isinstance(outputs, Sequence):
for value, name in zip(values, outputs):
value.name = name
return values

@property
def used_opsets(self) -> UsedOpsets:
return self._used_opsets
Loading
Loading