diff --git a/onnxscript/ir/__init__.py b/onnxscript/ir/__init__.py
index 40622fd9b1..04b5574c0b 100644
--- a/onnxscript/ir/__init__.py
+++ b/onnxscript/ir/__init__.py
@@ -8,6 +8,7 @@
     "traversal",
     "convenience",
     "external_data",
+    "tape",
     # IR classes
     "Tensor",
     "ExternalTensor",
@@ -80,7 +81,7 @@
     "save",
 ]
 
-from onnxscript.ir import convenience, external_data, passes, serde, traversal
+from onnxscript.ir import convenience, external_data, passes, serde, tape, traversal
 from onnxscript.ir._convenience._constructors import node, tensor
 from onnxscript.ir._core import (
     Attr,
diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py
index ddb0e80309..e13a3fa978 100644
--- a/onnxscript/ir/_core.py
+++ b/onnxscript/ir/_core.py
@@ -1135,7 +1135,7 @@ def __init__(
         num_outputs: int | None = None,
         outputs: Sequence[Value] | None = None,
         version: int | None = None,
-        graph: Graph | None = None,
+        graph: Graph | Function | None = None,
         name: str | None = None,
         doc_string: str | None = None,
         metadata_props: dict[str, str] | None = None,
@@ -1187,7 +1187,7 @@ def __init__(
         self._version: int | None = version
         self._metadata: _metadata.MetadataStore | None = None
         self._metadata_props: dict[str, str] | None = metadata_props
-        self._graph: Graph | None = graph
+        self._graph: Graph | Function | None = graph
         self.doc_string = doc_string
 
         # Add the node as a use of the inputs
@@ -1432,11 +1432,11 @@ def metadata_props(self) -> dict[str, str]:
         return self._metadata_props
 
     @property
-    def graph(self) -> Graph | None:
+    def graph(self) -> Graph | Function | None:
         return self._graph
 
     @graph.setter
-    def graph(self, value: Graph | None) -> None:
+    def graph(self, value: Graph | Function | None) -> None:
         self._graph = value
 
     def op_identifier(self) -> _protocols.OperatorIdentifier:
@@ -2162,7 +2162,7 @@ def sort(self) -> None:
 
         This sort is stable. It preserves the original order as much as possible.
 
-        Referece: https://github.com/madelson/MedallionTopologicalSort#stable-sort
+        Reference: https://github.com/madelson/MedallionTopologicalSort#stable-sort
 
         Raises:
             ValueError: If the graph contains a cycle, making topological sorting impossible.
@@ -2170,7 +2170,7 @@ def sort(self) -> None:
         # Obtain all nodes from the graph and its subgraphs for sorting
         nodes = list(onnxscript.ir.traversal.RecursiveGraphIterator(self))
         # Store the sorted nodes of each subgraph
-        sorted_nodes_by_graph: dict[Graph, list[Node]] = {
+        sorted_nodes_by_graph: dict[Graph | Function, list[Node]] = {
             graph: [] for graph in {node.graph for node in nodes if node.graph is not None}
         }
         # TODO: Explain why we need to store direct predecessors and children and why
@@ -2193,7 +2193,7 @@ def add_predecessor(child: Node, predecessor: Node | None) -> None:
                 node_depth[predecessor] += 1
 
         # 1. Build the direct predecessors of each node and the depth of each node
-        # for sorting topolocally using Kahn's algorithm.
+        # for sorting topologically using Kahn's algorithm.
         # Note that when a node contains graph attributes (aka. has subgraphs),
         # we consider all nodes in the subgraphs *predecessors* of this node. This
         # way we ensure the implicit dependencies of the subgraphs are captured
@@ -2718,11 +2718,11 @@ def remove(self, nodes: Node | Iterable[Node], /, safe: bool = False) -> None:
         """
         self._graph.remove(nodes, safe=safe)
 
-    def insert_after(self, node: Node, new_nodes: Iterable[Node], /) -> None:
+    def insert_after(self, node: Node, new_nodes: Iterable[Node] | Node, /) -> None:
         """Insert new nodes after the given node in O(#new_nodes) time."""
         self._graph.insert_after(node, new_nodes)
 
-    def insert_before(self, node: Node, new_nodes: Iterable[Node], /) -> None:
+    def insert_before(self, node: Node, new_nodes: Iterable[Node] | Node, /) -> None:
         """Insert new nodes before the given node in O(#new_nodes) time."""
         self._graph.insert_before(node, new_nodes)
 
diff --git a/onnxscript/ir/_protocols.py b/onnxscript/ir/_protocols.py
index 9d038602fc..fbc2c7c054 100644
--- a/onnxscript/ir/_protocols.py
+++ b/onnxscript/ir/_protocols.py
@@ -320,11 +320,15 @@ def remove(self, node: NodeProtocol, /) -> None:
         """Remove a node from the graph."""
         ...
 
-    def insert_after(self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol], /) -> None:
+    def insert_after(
+        self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol] | NodeProtocol, /
+    ) -> None:
         """Insert new nodes after the given node."""
         ...
 
-    def insert_before(self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol], /) -> None:
+    def insert_before(
+        self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol] | NodeProtocol, /
+    ) -> None:
         """Insert new nodes before the given node."""
         ...
 
@@ -589,11 +593,15 @@ def remove(self, node: NodeProtocol, /) -> None:
         """Remove a node from the function."""
         ...
 
-    def insert_after(self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol], /) -> None:
+    def insert_after(
+        self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol] | NodeProtocol, /
+    ) -> None:
         """Insert new nodes after the given node."""
         ...
 
-    def insert_before(self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol], /) -> None:
+    def insert_before(
+        self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol] | NodeProtocol, /
+    ) -> None:
         """Insert new nodes before the given node."""
         ...
 
diff --git a/onnxscript/ir/_tape.py b/onnxscript/ir/_tape.py
index 752a52a243..0a63118d4f 100644
--- a/onnxscript/ir/_tape.py
+++ b/onnxscript/ir/_tape.py
@@ -2,26 +2,63 @@
 # Licensed under the MIT License.
 """Convenience methods for constructing the IR."""
 
-# NOTE: This is a temporary solution for constructing the IR. It should be replaced
-# with a more permanent solution in the future.
-
 from __future__ import annotations
 
-from typing import Any, Iterable, Iterator, List, Mapping, Optional, Sequence, Tuple
+from typing import (
+    Any,
+    Mapping,
+    Optional,
+    Sequence,
+    Tuple,
+)
 
 from onnxscript import ir
 from onnxscript.ir import _convenience
 
+# A type representing the domains/versions used in creating nodes in IR.
+UsedOpsets = set[Tuple[str, Optional[int]]]
+
+
+class Tape:
+    """Tape class.
+
+    A tape is a recorder that collects nodes and initializers that are created so
+    that they can be used for creating a graph.
+
+    Example::
+        from onnxscript import ir
+
+        tape = ir.tape.Tape()
+        a = tape.initializer(ir.tensor([1, 2, 3], name="a"))
+        b: ir.Value = ...
+        c: ir.Value = ...
+        x = tape.op("Add", [a, b], attributes={"alpha": 1.0})
+        y = tape.op("Mul", [x, c], attributes={"beta": 2.0})
+        model = ir.Model(
+            graph := ir.Graph(
+                inputs=[b, c],
+                outputs=[y],
+                nodes=tape.nodes,
+                initializers=tape.initializers,
+                opset_imports={"": 20},
+            ),
+            ir_version=10,
+        )
 
-class Tape(Iterable[ir.Node]):
-    """A tape for recording nodes that are created."""
+    Attributes:
+        graph_like: The graph to append the new nodes and initializers to. When
+            it is None, the nodes and initializers are created without being owned by a graph.
+            Initializers will not be added to functions because ONNX does not support them.
+    """
 
-    def __init__(self) -> None:
+    def __init__(self, graph_like: ir.Graph | ir.Function | None = None) -> None:
         self._nodes: list[ir.Node] = []
         self._initializers: list[ir.Value] = []
+        self._used_opsets: UsedOpsets = set()
+        self.graph_like = graph_like
 
-    def __iter__(self) -> Iterator[ir.Node]:
-        return iter(self._nodes)
+    def __repr__(self) -> str:
+        return f"Tape(nodes={self._nodes}, initializers={self._initializers})"
 
     @property
     def nodes(self) -> Sequence[ir.Node]:
@@ -31,19 +68,43 @@ def nodes(self) -> Sequence[ir.Node]:
     def initializers(self) -> Sequence[ir.Value]:
         return tuple(self._initializers)
 
+    @property
+    def used_opsets(self) -> UsedOpsets:
+        return self._used_opsets
+
     def op(
         self,
         op_type: str,
         inputs: Sequence[ir.Value | None],
         attributes: Mapping[str, _convenience.SupportedAttrTypes] | None = None,
+        *,
         domain: str = "",
+        overload: str = "",
+        version: int | None = None,
+        graph: ir.Graph | None = None,
+        name: str | None = None,
+        doc_string: str | None = None,
+        metadata_props: dict[str, str] | None = None,
     ) -> ir.Value:
         if attributes is None:
             attrs: Sequence[ir.Attr | ir.RefAttr] = ()
         else:
             attrs = _convenience.convert_attributes(attributes)
-        node = ir.Node(domain, op_type, inputs, attributes=attrs, num_outputs=1)
+        node = ir.Node(
+            domain,
+            op_type,
+            inputs,
+            attributes=attrs,
+            num_outputs=1,
+            overload=overload,
+            version=version,
+            graph=graph or self.graph_like,
+            name=name,
+            doc_string=doc_string,
+            metadata_props=metadata_props,
+        )
         self._nodes.append(node)
+        self._used_opsets.add((domain, version))
         return node.outputs[0]
 
@@ -55,13 +116,32 @@ def op_multi_output(
         *,
         num_outputs: int,
         domain: str = "",
+        overload: str = "",
+        version: int | None = None,
+        graph: ir.Graph | None = None,
+        name: str | None = None,
+        doc_string: str | None = None,
+        metadata_props: dict[str, str] | None = None,
     ) -> Sequence[ir.Value]:
         if attributes is None:
             attrs: Sequence[ir.Attr | ir.RefAttr] = ()
         else:
             attrs = _convenience.convert_attributes(attributes)
-        node = ir.Node(domain, op_type, inputs, attributes=attrs, num_outputs=num_outputs)
+        node = ir.Node(
+            domain,
+            op_type,
+            inputs,
+            attributes=attrs,
+            num_outputs=num_outputs,
+            overload=overload,
+            version=version,
+            graph=graph or self.graph_like,
+            name=name,
+            doc_string=doc_string,
+            metadata_props=metadata_props,
+        )
         self._nodes.append(node)
+        self._used_opsets.add((domain, version))
         return node.outputs
 
@@ -74,20 +154,14 @@ def initializer(self, tensor: ir.TensorProtocol, name: str | None = None) -> ir.
             name=name, shape=shape, type=ir.TensorType(tensor.dtype), const_value=tensor
         )
         self._initializers.append(value)
+        if isinstance(self.graph_like, ir.Graph):
+            self.graph_like.register_initializer(value)
         return value
 
-# A type representing the domains/versions used in creating nodes in IR.
-UsedOpsets = List[Tuple[str, Optional[int]]]
-
-
 class Builder(Tape):
     """An extension of the tape that provides a more convenient API for constructing the IR."""
 
-    def __init__(self):
-        super().__init__()
-        self._used_opsets: UsedOpsets = []
-
     def __getattr__(self, op_type: str) -> Any:
         return lambda *args, **kwargs: self._make_node(op_type, args, kwargs)
 
@@ -101,20 +175,22 @@ def _make_node(self, op_type: str, inputs: Sequence[ir.Value], kwargs: dict[str,
         assert isinstance(outputs, int)
         num_outputs = outputs
 
-        self._used_opsets.append((domain, version))
         if num_outputs == 1:
-            value = super().op(op_type, inputs=inputs, attributes=kwargs, domain=domain)
+            value = super().op(
+                op_type, inputs=inputs, attributes=kwargs, domain=domain, version=version
+            )
             if isinstance(outputs, Sequence):
                 value.name = outputs[0]
             return value
         values = super().op_multi_output(
-            op_type, inputs=inputs, attributes=kwargs, domain=domain, num_outputs=num_outputs
+            op_type,
+            inputs=inputs,
+            attributes=kwargs,
+            domain=domain,
+            version=version,
+            num_outputs=num_outputs,
         )
         if isinstance(outputs, Sequence):
             for value, name in zip(values, outputs):
                 value.name = name
         return values
-
-    @property
-    def used_opsets(self) -> UsedOpsets:
-        return self._used_opsets
diff --git a/onnxscript/ir/_tape_test.py b/onnxscript/ir/_tape_test.py
new file mode 100644
index 0000000000..922c6d7eaa
--- /dev/null
+++ b/onnxscript/ir/_tape_test.py
@@ -0,0 +1,76 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+from __future__ import annotations
+
+import unittest
+
+from onnxscript import ir
+
+
+class TestTape(unittest.TestCase):
+    def test_op(self):
+        # Record a single Add node on the tape
+        # Define the node inputs
+        inputs = [
+            ir.Value(
+                name="input_a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2))
+            ),
+            ir.Value(
+                name="input_b", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2))
+            ),
+        ]
+
+        tape = ir.tape.Tape()
+
+        _ = tape.op("Add", inputs=inputs)
+
+        self.assertEqual([n.op_type for n in tape.nodes], ["Add"])
+
+    def test_initializers(self):
+        inputs = [
+            ir.Value(
+                name="input_a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2))
+            ),
+            ir.Value(
+                name="input_b",
+                type=ir.TensorType(ir.DataType.FLOAT),
+                shape=ir.Shape((2, 1)),
+                const_value=ir.tensor([[42]] * 2, dtype=ir.DataType.FLOAT),
+            ),
+        ]
+
+        tape = ir.tape.Tape()
+
+        # Shape and type are not explicitly set for the initializer but it should still work
+        initializer = tape.initializer(
+            ir.tensor([[2, 3]], dtype=ir.DataType.FLOAT), name="initializer"
+        )
+        val_add = tape.op("Add", inputs=inputs)
+        _ = tape.op("Mul", inputs=[val_add, initializer])
+
+        self.assertEqual([n.op_type for n in tape.nodes], ["Add", "Mul"])
+        self.assertEqual(tape.initializers, (initializer,))
+
+    def test_op_multi_out(self):
+        inputs = [
+            ir.Value(
+                name="input_a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2))
+            ),
+            ir.Value(
+                name="input_b",
+                type=ir.TensorType(ir.DataType.FLOAT),
+                shape=ir.Shape((2, 1)),
+                const_value=ir.tensor([[42]] * 2, dtype=ir.DataType.FLOAT),
+            ),
+        ]
+
+        tape = ir.tape.Tape()
+
+        out1, out2, out3 = tape.op_multi_output("SomeOp", inputs=inputs, num_outputs=3)  # pylint: disable=unbalanced-tuple-unpacking
+        _ = tape.op("SomeOtherOp", inputs=[out1, out2, out3])
+
+        self.assertEqual([n.op_type for n in tape.nodes], ["SomeOp", "SomeOtherOp"])
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/onnxscript/ir/passes/common/shape_inference_test.py b/onnxscript/ir/passes/common/shape_inference_test.py
index 3fc08400e3..da67b4c1a7 100644
--- a/onnxscript/ir/passes/common/shape_inference_test.py
+++ b/onnxscript/ir/passes/common/shape_inference_test.py
@@ -23,19 +23,21 @@ def test_pass(self):
             ),
         ]
 
-        add_node = ir.Node("", "Add", inputs=inputs)
+        tape = ir.tape.Tape()
+
+        output = tape.op("Add", inputs=inputs)
         model = ir.Model(
             ir.Graph(
                 inputs=inputs,
-                outputs=add_node.outputs,
-                nodes=[add_node],
+                outputs=[output],
+                nodes=tape.nodes,
                 opset_imports={"": 20},
             ),
             ir_version=10,
         )
 
-        self.assertIsNone(add_node.outputs[0].shape)
-        self.assertIsNone(add_node.outputs[0].dtype)
+        self.assertIsNone(output.shape)
+        self.assertIsNone(output.dtype)
 
         # Perform shape inference
         result = shape_inference.ShapeInferencePass()(model)
@@ -62,30 +64,30 @@ def test_pass_with_initializers(self):
             ),
         ]
 
+        tape = ir.tape.Tape()
+
         # Shape and type are not explicitly set for the initializer but it should still work
         initializer = ir.Value(
             name="initializer", const_value=ir.tensor([[2, 3]], dtype=ir.DataType.FLOAT)
         )
-
-        add_node = ir.Node("", "Add", inputs=[*inputs])
-        mul_node = ir.Node("", "Mul", inputs=[add_node.outputs[0], initializer])
+        val_add = tape.op("Add", inputs=inputs)
+        val_mul = tape.op("Mul", inputs=[val_add, initializer])
 
         model = ir.Model(
-            graph := ir.Graph(
+            ir.Graph(
                 inputs=inputs,
-                outputs=mul_node.outputs,
-                nodes=[add_node, mul_node],
+                outputs=[val_mul],
+                nodes=tape.nodes,
                 opset_imports={"": 20},
+                initializers=[inputs[1], initializer],
             ),
             ir_version=10,
         )
-        graph.register_initializer(inputs[1])
-        graph.register_initializer(initializer)
 
-        self.assertIsNone(add_node.outputs[0].shape)
-        self.assertIsNone(add_node.outputs[0].dtype)
-        self.assertIsNone(mul_node.outputs[0].shape)
-        self.assertIsNone(mul_node.outputs[0].dtype)
+        self.assertIsNone(val_add.shape)
+        self.assertIsNone(val_add.dtype)
+        self.assertIsNone(val_mul.shape)
+        self.assertIsNone(val_mul.dtype)
         self.assertIsNone(initializer.shape)
         self.assertIsNone(initializer.dtype)
 
@@ -128,10 +130,10 @@ def test_pass_with_initializers(self):
         )
 
         # Check that the original model is not modified
-        self.assertIsNone(add_node.outputs[0].shape)
-        self.assertIsNone(add_node.outputs[0].dtype)
-        self.assertIsNone(mul_node.outputs[0].shape)
-        self.assertIsNone(mul_node.outputs[0].dtype)
+        self.assertIsNone(val_add.shape)
+        self.assertIsNone(val_add.dtype)
+        self.assertIsNone(val_mul.shape)
+        self.assertIsNone(val_mul.dtype)
         self.assertEqual(len(model.graph.inputs), 2)
         self.assertEqual(len(model.graph.initializers), 2)
         self.assertIs(model.graph.initializers["input_b"].const_value, inputs[1].const_value)
diff --git a/onnxscript/ir/tape.py b/onnxscript/ir/tape.py
new file mode 100644
index 0000000000..9270dcdcec
--- /dev/null
+++ b/onnxscript/ir/tape.py
@@ -0,0 +1,15 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+"""Taping module to facilitate building IR graphs."""
+
+# NOTE: Be *selective* about what this module exports because it is part of the public API.
+
+from __future__ import annotations
+
+__all__ = [
+    "Tape",
+]
+
+from onnxscript.ir._tape import Tape
+
+Tape.__module__ = __name__
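
Usage sketch for the new graph_like parameter (illustrative only, not part of the diff: it follows the Tape docstring and tests added above, assumes ir.Graph.outputs is a mutable list, and the value/initializer names a, b, w, x, y are made up)::

    from onnxscript import ir

    # Graph inputs (hypothetical names).
    a = ir.Value(name="a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2)))
    b = ir.Value(name="b", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2)))

    # Start from a graph with no nodes and let the tape append to it directly.
    graph = ir.Graph(inputs=[a, b], outputs=[], nodes=[], opset_imports={"": 20})
    tape = ir.tape.Tape(graph_like=graph)

    # Nodes created through the tape are constructed with graph=graph or self.graph_like,
    # so they land in `graph`; initializers are registered via register_initializer().
    w = tape.initializer(ir.tensor([[2.0, 3.0]], dtype=ir.DataType.FLOAT), name="w")
    x = tape.op("Add", inputs=[a, b])
    y = tape.op("Mul", inputs=[x, w])

    # Assumption: graph.outputs is a mutable list that can be extended after construction.
    graph.outputs.append(y)

    assert [node.op_type for node in graph] == ["Add", "Mul"]
    assert "w" in graph.initializers
    assert ("", None) in tape.used_opsets

Passing graph_like up front avoids the separate step of building an ir.Graph from tape.nodes and tape.initializers that the Tape docstring example shows.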