From 573ada391faea5d1b3a933a95570e462742fa9b8 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 24 Mar 2025 11:57:15 -0700 Subject: [PATCH 01/14] tape --- onnxscript/ir/_tape.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/onnxscript/ir/_tape.py b/onnxscript/ir/_tape.py index 752a52a243..1dd3719584 100644 --- a/onnxscript/ir/_tape.py +++ b/onnxscript/ir/_tape.py @@ -36,7 +36,14 @@ def op( op_type: str, inputs: Sequence[ir.Value | None], attributes: Mapping[str, _convenience.SupportedAttrTypes] | None = None, - domain: str = "", + *, + domain: str = "", + overload: str = "", + version: int | None = None, + graph: ir.Graph | None = None, + name: str | None = None, + doc_string: str | None = None, + metadata_props: dict[str, str] | None = None, ) -> ir.Value: if attributes is None: attrs: Sequence[ir.Attr | ir.RefAttr] = () @@ -53,8 +60,14 @@ def op_multi_output( inputs: Sequence[ir.Value | None], attributes: Mapping[str, _convenience.SupportedAttrTypes] | None = None, *, - num_outputs: int, - domain: str = "", + domain: str = "", + overload: str = "", + num_outputs: int | None = None, + version: int | None = None, + graph: ir.Graph | None = None, + name: str | None = None, + doc_string: str | None = None, + metadata_props: dict[str, str] | None = None, ) -> Sequence[ir.Value]: if attributes is None: attrs: Sequence[ir.Attr | ir.RefAttr] = () From 903e48ef020be7df3966219a0c9848bb265df9c5 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 24 Mar 2025 12:50:34 -0700 Subject: [PATCH 02/14] Create a public tape --- onnxscript/ir/_tape.py | 32 ++++++++++++-- .../ir/passes/common/shape_inference_test.py | 44 ++++++++++--------- 2 files changed, 52 insertions(+), 24 deletions(-) diff --git a/onnxscript/ir/_tape.py b/onnxscript/ir/_tape.py index 1dd3719584..c385a58f19 100644 --- a/onnxscript/ir/_tape.py +++ b/onnxscript/ir/_tape.py @@ -37,7 +37,7 @@ def op( inputs: Sequence[ir.Value | None], attributes: Mapping[str, _convenience.SupportedAttrTypes] | None = None, *, - domain: str = "", + domain: str = "", overload: str = "", version: int | None = None, graph: ir.Graph | None = None, @@ -49,7 +49,19 @@ def op( attrs: Sequence[ir.Attr | ir.RefAttr] = () else: attrs = _convenience.convert_attributes(attributes) - node = ir.Node(domain, op_type, inputs, attributes=attrs, num_outputs=1) + node = ir.Node( + domain, + op_type, + inputs, + attributes=attrs, + num_outputs=1, + overload=overload, + version=version, + graph=graph, + name=name, + doc_string=doc_string, + metadata_props=metadata_props, + ) self._nodes.append(node) return node.outputs[0] @@ -60,7 +72,7 @@ def op_multi_output( inputs: Sequence[ir.Value | None], attributes: Mapping[str, _convenience.SupportedAttrTypes] | None = None, *, - domain: str = "", + domain: str = "", overload: str = "", num_outputs: int | None = None, version: int | None = None, @@ -73,7 +85,19 @@ def op_multi_output( attrs: Sequence[ir.Attr | ir.RefAttr] = () else: attrs = _convenience.convert_attributes(attributes) - node = ir.Node(domain, op_type, inputs, attributes=attrs, num_outputs=num_outputs) + node = ir.Node( + domain, + op_type, + inputs, + attributes=attrs, + num_outputs=num_outputs, + overload=overload, + version=version, + graph=graph, + name=name, + doc_string=doc_string, + metadata_props=metadata_props, + ) self._nodes.append(node) return node.outputs diff --git a/onnxscript/ir/passes/common/shape_inference_test.py b/onnxscript/ir/passes/common/shape_inference_test.py index 3fc08400e3..9acfe05f24 100644 --- a/onnxscript/ir/passes/common/shape_inference_test.py +++ b/onnxscript/ir/passes/common/shape_inference_test.py @@ -7,6 +7,7 @@ import numpy as np from onnxscript import ir +from onnxscript.ir import building from onnxscript.ir.passes.common import shape_inference @@ -23,19 +24,21 @@ def test_pass(self): ), ] - add_node = ir.Node("", "Add", inputs=inputs) + tape = building.Tape() + + output = tape.op("Add", inputs=inputs) model = ir.Model( ir.Graph( inputs=inputs, - outputs=add_node.outputs, - nodes=[add_node], + outputs=[output], + nodes=tape.nodes, opset_imports={"": 20}, ), ir_version=10, ) - self.assertIsNone(add_node.outputs[0].shape) - self.assertIsNone(add_node.outputs[0].dtype) + self.assertIsNone(output.shape) + self.assertIsNone(output.dtype) # Perform shape inference result = shape_inference.ShapeInferencePass()(model) @@ -62,19 +65,20 @@ def test_pass_with_initializers(self): ), ] + tape = building.Tape() + # Shape and type are not explicitly set for the initializer but it should still work - initializer = ir.Value( - name="initializer", const_value=ir.tensor([[2, 3]], dtype=ir.DataType.FLOAT) + initializer = tape.initializer( + ir.tensor([[2, 3]], dtype=ir.DataType.FLOAT, name="initializer") ) - - add_node = ir.Node("", "Add", inputs=[*inputs]) - mul_node = ir.Node("", "Mul", inputs=[add_node.outputs[0], initializer]) + val_add = tape.op("Add", inputs=[*inputs]) + val_mul = tape.op("Mul", inputs=[val_add, initializer]) model = ir.Model( graph := ir.Graph( inputs=inputs, - outputs=mul_node.outputs, - nodes=[add_node, mul_node], + outputs=[val_mul], + nodes=tape.nodes, opset_imports={"": 20}, ), ir_version=10, @@ -82,10 +86,10 @@ def test_pass_with_initializers(self): graph.register_initializer(inputs[1]) graph.register_initializer(initializer) - self.assertIsNone(add_node.outputs[0].shape) - self.assertIsNone(add_node.outputs[0].dtype) - self.assertIsNone(mul_node.outputs[0].shape) - self.assertIsNone(mul_node.outputs[0].dtype) + self.assertIsNone(val_add.shape) + self.assertIsNone(val_add.dtype) + self.assertIsNone(val_mul.shape) + self.assertIsNone(val_mul.dtype) self.assertIsNone(initializer.shape) self.assertIsNone(initializer.dtype) @@ -128,10 +132,10 @@ def test_pass_with_initializers(self): ) # Check that the original model is not modified - self.assertIsNone(add_node.outputs[0].shape) - self.assertIsNone(add_node.outputs[0].dtype) - self.assertIsNone(mul_node.outputs[0].shape) - self.assertIsNone(mul_node.outputs[0].dtype) + self.assertIsNone(val_add.shape) + self.assertIsNone(val_add.dtype) + self.assertIsNone(val_mul.shape) + self.assertIsNone(val_mul.dtype) self.assertEqual(len(model.graph.inputs), 2) self.assertEqual(len(model.graph.initializers), 2) self.assertIs(model.graph.initializers["input_b"].const_value, inputs[1].const_value) From 627d30f4fb1523df8f874eef50f93ad490c16e0f Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 24 Mar 2025 12:58:15 -0700 Subject: [PATCH 03/14] tape --- onnxscript/ir/__init__.py | 3 +- onnxscript/ir/_tape.py | 30 ++++++++++++++++--- .../ir/passes/common/shape_inference_test.py | 5 ++-- onnxscript/ir/tape.py | 15 ++++++++++ 4 files changed, 45 insertions(+), 8 deletions(-) create mode 100644 onnxscript/ir/tape.py diff --git a/onnxscript/ir/__init__.py b/onnxscript/ir/__init__.py index a9918e9713..d4c444ac2c 100644 --- a/onnxscript/ir/__init__.py +++ b/onnxscript/ir/__init__.py @@ -8,6 +8,7 @@ "traversal", "convenience", "external_data", + "tape", # IR classes "Tensor", "ExternalTensor", @@ -79,7 +80,7 @@ "save", ] -from onnxscript.ir import convenience, external_data, passes, serde, traversal +from onnxscript.ir import convenience, external_data, passes, serde, tape, traversal from onnxscript.ir._convenience import tensor from onnxscript.ir._core import ( Attr, diff --git a/onnxscript/ir/_tape.py b/onnxscript/ir/_tape.py index c385a58f19..607a83d896 100644 --- a/onnxscript/ir/_tape.py +++ b/onnxscript/ir/_tape.py @@ -2,9 +2,6 @@ # Licensed under the MIT License. """Convenience methods for constructing the IR.""" -# NOTE: This is a temporary solution for constructing the IR. It should be replaced -# with a more permanent solution in the future. - from __future__ import annotations from typing import Any, Iterable, Iterator, List, Mapping, Optional, Sequence, Tuple @@ -14,7 +11,32 @@ class Tape(Iterable[ir.Node]): - """A tape for recording nodes that are created.""" + """Tape class. + + A tape is a recorder that collects nodes and initializers that are created so + that they can be used for creating a graph. + + Example:: + from onnxscript import ir + + tape = Tape() + a = tape.initializer(ir.tensor([1, 2, 3], name="a")) + b: ir.Value = ... + c: ir.Value = ... + x = tape.op("Add", [a, b], attributes={"alpha": 1.0}) + y = tape.op("Mul", [x, c], attributes={"beta": 2.0}) + model = ir.Model( + graph := ir.Graph( + inputs=[b, c], + outputs=[y], + nodes=tape.nodes, + opset_imports={"": 20}, + ), + ir_version=10, + ) + for initializer in tape.initializers: + graph.register_initializer(initializer) + """ def __init__(self) -> None: self._nodes: list[ir.Node] = [] diff --git a/onnxscript/ir/passes/common/shape_inference_test.py b/onnxscript/ir/passes/common/shape_inference_test.py index 9acfe05f24..db2b9e43cf 100644 --- a/onnxscript/ir/passes/common/shape_inference_test.py +++ b/onnxscript/ir/passes/common/shape_inference_test.py @@ -7,7 +7,6 @@ import numpy as np from onnxscript import ir -from onnxscript.ir import building from onnxscript.ir.passes.common import shape_inference @@ -24,7 +23,7 @@ def test_pass(self): ), ] - tape = building.Tape() + tape = ir.tape.Tape() output = tape.op("Add", inputs=inputs) @@ -65,7 +64,7 @@ def test_pass_with_initializers(self): ), ] - tape = building.Tape() + tape = ir.tape.Tape() # Shape and type are not explicitly set for the initializer but it should still work initializer = tape.initializer( diff --git a/onnxscript/ir/tape.py b/onnxscript/ir/tape.py new file mode 100644 index 0000000000..9270dcdcec --- /dev/null +++ b/onnxscript/ir/tape.py @@ -0,0 +1,15 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Taping module to facilitate building IR graphs.""" + +# NOTE: Be *selective* about what this module exports because it is part of the public API. + +from __future__ import annotations + +__all__ = [ + "Tape", +] + +from onnxscript.ir._tape import Tape + +Tape.__module__ = __name__ From 1a43ace6e0e1273fb29011cab231a23abf0c5e2d Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 24 Mar 2025 13:05:39 -0700 Subject: [PATCH 04/14] fix tests --- onnxscript/ir/_tape.py | 18 +++++++++++++----- .../ir/passes/common/shape_inference_test.py | 9 ++++----- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/onnxscript/ir/_tape.py b/onnxscript/ir/_tape.py index 607a83d896..e15e2daf8c 100644 --- a/onnxscript/ir/_tape.py +++ b/onnxscript/ir/_tape.py @@ -4,13 +4,22 @@ from __future__ import annotations -from typing import Any, Iterable, Iterator, List, Mapping, Optional, Sequence, Tuple +from typing import ( + Any, + Iterable, + Iterator, + List, + Mapping, + Optional, + Sequence, + Tuple, +) from onnxscript import ir -from onnxscript.ir import _convenience +from onnxscript.ir import _convenience, _core -class Tape(Iterable[ir.Node]): +class Tape(Iterable[_core.Node]): """Tape class. A tape is a recorder that collects nodes and initializers that are created so @@ -30,12 +39,11 @@ class Tape(Iterable[ir.Node]): inputs=[b, c], outputs=[y], nodes=tape.nodes, + initializers=tape.initializers opset_imports={"": 20}, ), ir_version=10, ) - for initializer in tape.initializers: - graph.register_initializer(initializer) """ def __init__(self) -> None: diff --git a/onnxscript/ir/passes/common/shape_inference_test.py b/onnxscript/ir/passes/common/shape_inference_test.py index db2b9e43cf..44b40706ff 100644 --- a/onnxscript/ir/passes/common/shape_inference_test.py +++ b/onnxscript/ir/passes/common/shape_inference_test.py @@ -67,23 +67,22 @@ def test_pass_with_initializers(self): tape = ir.tape.Tape() # Shape and type are not explicitly set for the initializer but it should still work - initializer = tape.initializer( - ir.tensor([[2, 3]], dtype=ir.DataType.FLOAT, name="initializer") + initializer = ir.Value( + name="initializer", const_value=ir.tensor([[2, 3]], dtype=ir.DataType.FLOAT) ) val_add = tape.op("Add", inputs=[*inputs]) val_mul = tape.op("Mul", inputs=[val_add, initializer]) model = ir.Model( - graph := ir.Graph( + ir.Graph( inputs=inputs, outputs=[val_mul], nodes=tape.nodes, opset_imports={"": 20}, + initializers=[inputs[1], initializer], ), ir_version=10, ) - graph.register_initializer(inputs[1]) - graph.register_initializer(initializer) self.assertIsNone(val_add.shape) self.assertIsNone(val_add.dtype) From 1031e429f01763c6c6bb3e110936fe651269b198 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 24 Mar 2025 13:06:44 -0700 Subject: [PATCH 05/14] inputs --- onnxscript/ir/passes/common/shape_inference_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/ir/passes/common/shape_inference_test.py b/onnxscript/ir/passes/common/shape_inference_test.py index 44b40706ff..da67b4c1a7 100644 --- a/onnxscript/ir/passes/common/shape_inference_test.py +++ b/onnxscript/ir/passes/common/shape_inference_test.py @@ -70,7 +70,7 @@ def test_pass_with_initializers(self): initializer = ir.Value( name="initializer", const_value=ir.tensor([[2, 3]], dtype=ir.DataType.FLOAT) ) - val_add = tape.op("Add", inputs=[*inputs]) + val_add = tape.op("Add", inputs=inputs) val_mul = tape.op("Mul", inputs=[val_add, initializer]) model = ir.Model( From fad31225db64370d3d2442a209d6127e902d3537 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 24 Mar 2025 13:08:11 -0700 Subject: [PATCH 06/14] num_outputs --- onnxscript/ir/_tape.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/ir/_tape.py b/onnxscript/ir/_tape.py index e15e2daf8c..94ea752c15 100644 --- a/onnxscript/ir/_tape.py +++ b/onnxscript/ir/_tape.py @@ -102,9 +102,9 @@ def op_multi_output( inputs: Sequence[ir.Value | None], attributes: Mapping[str, _convenience.SupportedAttrTypes] | None = None, *, + num_outputs: int, domain: str = "", overload: str = "", - num_outputs: int | None = None, version: int | None = None, graph: ir.Graph | None = None, name: str | None = None, From 4e6284cbd66cecfd509b35bd9f82c725e23cf953 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 25 Mar 2025 09:22:28 -0700 Subject: [PATCH 07/14] example --- onnxscript/ir/_tape.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/ir/_tape.py b/onnxscript/ir/_tape.py index 94ea752c15..c718feb2b5 100644 --- a/onnxscript/ir/_tape.py +++ b/onnxscript/ir/_tape.py @@ -28,7 +28,7 @@ class Tape(Iterable[_core.Node]): Example:: from onnxscript import ir - tape = Tape() + tape = ir.tape.Tape() a = tape.initializer(ir.tensor([1, 2, 3], name="a")) b: ir.Value = ... c: ir.Value = ... From 224c86f0cb6ea61d23e30e3d6e1efaa6bd262706 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 28 Mar 2025 10:06:18 -0700 Subject: [PATCH 08/14] used_opsets --- onnxscript/ir/_tape.py | 45 +++++++++++++++++++++--------------------- 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/onnxscript/ir/_tape.py b/onnxscript/ir/_tape.py index c718feb2b5..b646a63679 100644 --- a/onnxscript/ir/_tape.py +++ b/onnxscript/ir/_tape.py @@ -6,9 +6,6 @@ from typing import ( Any, - Iterable, - Iterator, - List, Mapping, Optional, Sequence, @@ -16,10 +13,13 @@ ) from onnxscript import ir -from onnxscript.ir import _convenience, _core +from onnxscript.ir import _convenience + +# A type representing the domains/versions used in creating nodes in IR. +UsedOpsets = set[Tuple[str, Optional[int]]] -class Tape(Iterable[_core.Node]): +class Tape: """Tape class. A tape is a recorder that collects nodes and initializers that are created so @@ -49,9 +49,10 @@ class Tape(Iterable[_core.Node]): def __init__(self) -> None: self._nodes: list[ir.Node] = [] self._initializers: list[ir.Value] = [] + self._used_opsets: UsedOpsets = set() - def __iter__(self) -> Iterator[ir.Node]: - return iter(self._nodes) + def __repr__(self) -> str: + return f"Tape(nodes={self._nodes}, initializers={self._initializers})" @property def nodes(self) -> Sequence[ir.Node]: @@ -61,6 +62,10 @@ def nodes(self) -> Sequence[ir.Node]: def initializers(self) -> Sequence[ir.Value]: return tuple(self._initializers) + @property + def used_opsets(self) -> UsedOpsets: + return self._used_opsets + def op( self, op_type: str, @@ -93,6 +98,7 @@ def op( metadata_props=metadata_props, ) self._nodes.append(node) + self._used_opsets.add((domain, version)) return node.outputs[0] @@ -129,6 +135,7 @@ def op_multi_output( metadata_props=metadata_props, ) self._nodes.append(node) + self._used_opsets.add((domain, version)) return node.outputs @@ -144,17 +151,9 @@ def initializer(self, tensor: ir.TensorProtocol, name: str | None = None) -> ir. return value -# A type representing the domains/versions used in creating nodes in IR. -UsedOpsets = List[Tuple[str, Optional[int]]] - - class Builder(Tape): """An extension of the tape that provides a more convenient API for constructing the IR.""" - def __init__(self): - super().__init__() - self._used_opsets: UsedOpsets = [] - def __getattr__(self, op_type: str) -> Any: return lambda *args, **kwargs: self._make_node(op_type, args, kwargs) @@ -168,20 +167,22 @@ def _make_node(self, op_type: str, inputs: Sequence[ir.Value], kwargs: dict[str, assert isinstance(outputs, int) num_outputs = outputs - self._used_opsets.append((domain, version)) if num_outputs == 1: - value = super().op(op_type, inputs=inputs, attributes=kwargs, domain=domain) + value = super().op( + op_type, inputs=inputs, attributes=kwargs, domain=domain, version=version + ) if isinstance(outputs, Sequence): value.name = outputs[0] return value values = super().op_multi_output( - op_type, inputs=inputs, attributes=kwargs, domain=domain, num_outputs=num_outputs + op_type, + inputs=inputs, + attributes=kwargs, + domain=domain, + version=version, + num_outputs=num_outputs, ) if isinstance(outputs, Sequence): for value, name in zip(values, outputs): value.name = name return values - - @property - def used_opsets(self) -> UsedOpsets: - return self._used_opsets From 86266bdbd340990cf92c5740a679a2072e723de6 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 28 Mar 2025 10:11:47 -0700 Subject: [PATCH 09/14] Take a graph --- onnxscript/ir/_core.py | 8 ++++---- onnxscript/ir/_tape.py | 14 +++++++++++--- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py index ddb0e80309..ca0eb5d1f8 100644 --- a/onnxscript/ir/_core.py +++ b/onnxscript/ir/_core.py @@ -1135,7 +1135,7 @@ def __init__( num_outputs: int | None = None, outputs: Sequence[Value] | None = None, version: int | None = None, - graph: Graph | None = None, + graph: Graph | Function | None = None, name: str | None = None, doc_string: str | None = None, metadata_props: dict[str, str] | None = None, @@ -1187,7 +1187,7 @@ def __init__( self._version: int | None = version self._metadata: _metadata.MetadataStore | None = None self._metadata_props: dict[str, str] | None = metadata_props - self._graph: Graph | None = graph + self._graph: Graph | Function | None = graph self.doc_string = doc_string # Add the node as a use of the inputs @@ -1432,11 +1432,11 @@ def metadata_props(self) -> dict[str, str]: return self._metadata_props @property - def graph(self) -> Graph | None: + def graph(self) -> Graph | Function | None: return self._graph @graph.setter - def graph(self, value: Graph | None) -> None: + def graph(self, value: Graph | Function | None) -> None: self._graph = value def op_identifier(self) -> _protocols.OperatorIdentifier: diff --git a/onnxscript/ir/_tape.py b/onnxscript/ir/_tape.py index b646a63679..0a63118d4f 100644 --- a/onnxscript/ir/_tape.py +++ b/onnxscript/ir/_tape.py @@ -44,12 +44,18 @@ class Tape: ), ir_version=10, ) + + Attributes: + graph_like: The graph to append the new nodes and initializers to. When + it is None, the nodes and initializers are creating without owned by a graph. + Initializers will not be added to functions because it is not supported by ONNX. """ - def __init__(self) -> None: + def __init__(self, graph_like: ir.Graph | ir.Function | None = None) -> None: self._nodes: list[ir.Node] = [] self._initializers: list[ir.Value] = [] self._used_opsets: UsedOpsets = set() + self.graph_like = graph_like def __repr__(self) -> str: return f"Tape(nodes={self._nodes}, initializers={self._initializers})" @@ -92,7 +98,7 @@ def op( num_outputs=1, overload=overload, version=version, - graph=graph, + graph=graph or self.graph_like, name=name, doc_string=doc_string, metadata_props=metadata_props, @@ -129,7 +135,7 @@ def op_multi_output( num_outputs=num_outputs, overload=overload, version=version, - graph=graph, + graph=graph or self.graph_like, name=name, doc_string=doc_string, metadata_props=metadata_props, @@ -148,6 +154,8 @@ def initializer(self, tensor: ir.TensorProtocol, name: str | None = None) -> ir. name=name, shape=shape, type=ir.TensorType(tensor.dtype), const_value=tensor ) self._initializers.append(value) + if isinstance(self.graph_like, ir.Graph): + self.graph_like.register_initializer(value) return value From 1741b71b63c6fbd21ab647e02bc1537612599d19 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 28 Mar 2025 10:20:27 -0700 Subject: [PATCH 10/14] test --- onnxscript/ir/_tape_test.py | 78 +++++++++++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) create mode 100644 onnxscript/ir/_tape_test.py diff --git a/onnxscript/ir/_tape_test.py b/onnxscript/ir/_tape_test.py new file mode 100644 index 0000000000..22eab82ba8 --- /dev/null +++ b/onnxscript/ir/_tape_test.py @@ -0,0 +1,78 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import unittest + +import numpy as np + +from onnxscript import ir + + +class TestTape(unittest.TestCase): + def test_op(self): + # Create a simple ONNX model with shape inference + # Define the model + inputs = [ + ir.Value( + name="input_a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2)) + ), + ir.Value( + name="input_b", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2)) + ), + ] + + tape = ir.tape.Tape() + + _output = tape.op("Add", inputs=inputs) + + self.assertEqual([n.op_type for n in tape.nodes], ["Add"]) + + def test_initializers(self): + inputs = [ + ir.Value( + name="input_a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2)) + ), + ir.Value( + name="input_b", + type=ir.TensorType(ir.DataType.FLOAT), + shape=ir.Shape((2, 1)), + const_value=ir.tensor([[42]] * 2, dtype=ir.DataType.FLOAT), + ), + ] + + tape = ir.tape.Tape() + + # Shape and type are not explicitly set for the initializer but it should still work + initializer = tape.initializer( + ir.tensor([[2, 3]], dtype=ir.DataType.FLOAT), name="initializer" + ) + val_add = tape.op("Add", inputs=inputs) + _val_mul = tape.op("Mul", inputs=[val_add, initializer]) + + self.assertEqual([n.op_type for n in tape.nodes], ["Add", "Mul"]) + self.assertEqual(tape.initializers, (initializer,)) + + def test_op_multi_out(self): + inputs = [ + ir.Value( + name="input_a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2)) + ), + ir.Value( + name="input_b", + type=ir.TensorType(ir.DataType.FLOAT), + shape=ir.Shape((2, 1)), + const_value=ir.tensor([[42]] * 2, dtype=ir.DataType.FLOAT), + ), + ] + + tape = ir.tape.Tape() + + out1, out2, out3 = tape.op_multi_output("SomeOp", inputs=inputs, num_outputs=3) + _result = tape.op("SomeOtherOp", inputs=[out1, out2, out3]) + + self.assertEqual([n.op_type for n in tape.nodes], ["SomeOp", "SomeOtherOp"]) + + +if __name__ == "__main__": + unittest.main() From f6120f08d6829d5e7e1f7e0871d8c4058c88ce9c Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 31 Mar 2025 11:40:25 -0700 Subject: [PATCH 11/14] lint --- onnxscript/ir/_tape_test.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/onnxscript/ir/_tape_test.py b/onnxscript/ir/_tape_test.py index 22eab82ba8..922c6d7eaa 100644 --- a/onnxscript/ir/_tape_test.py +++ b/onnxscript/ir/_tape_test.py @@ -4,8 +4,6 @@ import unittest -import numpy as np - from onnxscript import ir @@ -24,7 +22,7 @@ def test_op(self): tape = ir.tape.Tape() - _output = tape.op("Add", inputs=inputs) + _ = tape.op("Add", inputs=inputs) self.assertEqual([n.op_type for n in tape.nodes], ["Add"]) @@ -48,7 +46,7 @@ def test_initializers(self): ir.tensor([[2, 3]], dtype=ir.DataType.FLOAT), name="initializer" ) val_add = tape.op("Add", inputs=inputs) - _val_mul = tape.op("Mul", inputs=[val_add, initializer]) + _ = tape.op("Mul", inputs=[val_add, initializer]) self.assertEqual([n.op_type for n in tape.nodes], ["Add", "Mul"]) self.assertEqual(tape.initializers, (initializer,)) @@ -68,8 +66,8 @@ def test_op_multi_out(self): tape = ir.tape.Tape() - out1, out2, out3 = tape.op_multi_output("SomeOp", inputs=inputs, num_outputs=3) - _result = tape.op("SomeOtherOp", inputs=[out1, out2, out3]) + out1, out2, out3 = tape.op_multi_output("SomeOp", inputs=inputs, num_outputs=3) # pylint: disable=unbalanced-tuple-unpacking + _ = tape.op("SomeOtherOp", inputs=[out1, out2, out3]) self.assertEqual([n.op_type for n in tape.nodes], ["SomeOp", "SomeOtherOp"]) From 081ae344c886809a43816dd219655800dc6955f7 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 8 Apr 2025 09:33:41 -0700 Subject: [PATCH 12/14] typing --- onnxscript/ir/_core.py | 4 ++-- onnxscript/ir/_protocols.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py index ca0eb5d1f8..3f2982ed57 100644 --- a/onnxscript/ir/_core.py +++ b/onnxscript/ir/_core.py @@ -2718,11 +2718,11 @@ def remove(self, nodes: Node | Iterable[Node], /, safe: bool = False) -> None: """ self._graph.remove(nodes, safe=safe) - def insert_after(self, node: Node, new_nodes: Iterable[Node], /) -> None: + def insert_after(self, node: Node, new_nodes: Iterable[Node] | Node, /) -> None: """Insert new nodes after the given node in O(#new_nodes) time.""" self._graph.insert_after(node, new_nodes) - def insert_before(self, node: Node, new_nodes: Iterable[Node], /) -> None: + def insert_before(self, node: Node, new_nodes: Iterable[Node] | Node, /) -> None: """Insert new nodes before the given node in O(#new_nodes) time.""" self._graph.insert_before(node, new_nodes) diff --git a/onnxscript/ir/_protocols.py b/onnxscript/ir/_protocols.py index 9d038602fc..975cda160a 100644 --- a/onnxscript/ir/_protocols.py +++ b/onnxscript/ir/_protocols.py @@ -320,11 +320,11 @@ def remove(self, node: NodeProtocol, /) -> None: """Remove a node from the graph.""" ... - def insert_after(self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol], /) -> None: + def insert_after(self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol] | NodeProtocol, /) -> None: """Insert new nodes after the given node.""" ... - def insert_before(self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol], /) -> None: + def insert_before(self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol] | NodeProtocol, /) -> None: """Insert new nodes before the given node.""" ... @@ -589,11 +589,11 @@ def remove(self, node: NodeProtocol, /) -> None: """Remove a node from the function.""" ... - def insert_after(self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol], /) -> None: + def insert_after(self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol] | NodeProtocol, /) -> None: """Insert new nodes after the given node.""" ... - def insert_before(self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol], /) -> None: + def insert_before(self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol] | NodeProtocol, /) -> None: """Insert new nodes before the given node.""" ... From 17457618fe1d36da2d642115270a9671303cce83 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 8 Apr 2025 09:39:22 -0700 Subject: [PATCH 13/14] type --- onnxscript/ir/_core.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py index 3f2982ed57..e13a3fa978 100644 --- a/onnxscript/ir/_core.py +++ b/onnxscript/ir/_core.py @@ -2162,7 +2162,7 @@ def sort(self) -> None: This sort is stable. It preserves the original order as much as possible. - Referece: https://github.com/madelson/MedallionTopologicalSort#stable-sort + Reference: https://github.com/madelson/MedallionTopologicalSort#stable-sort Raises: ValueError: If the graph contains a cycle, making topological sorting impossible. @@ -2170,7 +2170,7 @@ def sort(self) -> None: # Obtain all nodes from the graph and its subgraphs for sorting nodes = list(onnxscript.ir.traversal.RecursiveGraphIterator(self)) # Store the sorted nodes of each subgraph - sorted_nodes_by_graph: dict[Graph, list[Node]] = { + sorted_nodes_by_graph: dict[Graph | Function, list[Node]] = { graph: [] for graph in {node.graph for node in nodes if node.graph is not None} } # TODO: Explain why we need to store direct predecessors and children and why @@ -2193,7 +2193,7 @@ def add_predecessor(child: Node, predecessor: Node | None) -> None: node_depth[predecessor] += 1 # 1. Build the direct predecessors of each node and the depth of each node - # for sorting topolocally using Kahn's algorithm. + # for sorting topologically using Kahn's algorithm. # Note that when a node contains graph attributes (aka. has subgraphs), # we consider all nodes in the subgraphs *predecessors* of this node. This # way we ensure the implicit dependencies of the subgraphs are captured From f2da6348f74e1898e2bd1a3fb10ad372d589f7df Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 8 Apr 2025 09:41:03 -0700 Subject: [PATCH 14/14] format --- onnxscript/ir/_protocols.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/onnxscript/ir/_protocols.py b/onnxscript/ir/_protocols.py index 975cda160a..fbc2c7c054 100644 --- a/onnxscript/ir/_protocols.py +++ b/onnxscript/ir/_protocols.py @@ -320,11 +320,15 @@ def remove(self, node: NodeProtocol, /) -> None: """Remove a node from the graph.""" ... - def insert_after(self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol] | NodeProtocol, /) -> None: + def insert_after( + self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol] | NodeProtocol, / + ) -> None: """Insert new nodes after the given node.""" ... - def insert_before(self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol] | NodeProtocol, /) -> None: + def insert_before( + self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol] | NodeProtocol, / + ) -> None: """Insert new nodes before the given node.""" ... @@ -589,11 +593,15 @@ def remove(self, node: NodeProtocol, /) -> None: """Remove a node from the function.""" ... - def insert_after(self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol] | NodeProtocol, /) -> None: + def insert_after( + self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol] | NodeProtocol, / + ) -> None: """Insert new nodes after the given node.""" ... - def insert_before(self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol] | NodeProtocol, /) -> None: + def insert_before( + self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol] | NodeProtocol, / + ) -> None: """Insert new nodes before the given node.""" ...