From 9a078c93bdd1f74c0d315f9a4b3be48a017ebaaf Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 14 Apr 2025 12:08:42 -0700 Subject: [PATCH 1/2] [IR] Implement model.graphs() --- onnxscript/ir/_core.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py index b408898f71..948d29ac62 100644 --- a/onnxscript/ir/_core.py +++ b/onnxscript/ir/_core.py @@ -50,6 +50,7 @@ _name_authority, _protocols, _type_casting, + traversal, ) if typing.TYPE_CHECKING: @@ -2561,6 +2562,25 @@ def __repr__(self) -> str: graph={textwrap.indent(repr(self.graph), " " * 4).strip()} )""" + def graphs(self) -> Iterable[Graph]: + """Get all graphs and subgraphs in the model. + + This is a convenience method to traverse the model. Consider using + `onnxscript.ir.traversal.RecursiveGraphIterator` for more advanced + traversals on nodes. + """ + # NOTE(justinchuby): Given + # (1) how useful the method is + # (2) I couldn't find an appropriate name for it in `traversal.py` + # (3) Users familiar with onnxruntime optimization tools expect this method + # I created this method as a core method instead of an iterator in + # `traversal.py`. + seen_graphs: set[Graph] = set() + for node in traversal.RecursiveGraphIterator(self.graph): + if node.graph is not None and node.graph not in seen_graphs: + seen_graphs.add(node.graph) + yield node.graph + class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrintable): """IR functions. From bfb6d052b33d1b0e287df7427c05e6249577b90b Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 14 Apr 2025 12:13:11 -0700 Subject: [PATCH 2/2] test --- onnxscript/ir/_core.py | 3 +- onnxscript/ir/_core_test.py | 55 +++++++++++++++++++++++++++++++++++++ 2 files changed, 56 insertions(+), 2 deletions(-) diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py index 948d29ac62..4021b06635 100644 --- a/onnxscript/ir/_core.py +++ b/onnxscript/ir/_core.py @@ -50,7 +50,6 @@ _name_authority, _protocols, _type_casting, - traversal, ) if typing.TYPE_CHECKING: @@ -2576,7 +2575,7 @@ def graphs(self) -> Iterable[Graph]: # I created this method as a core method instead of an iterator in # `traversal.py`. seen_graphs: set[Graph] = set() - for node in traversal.RecursiveGraphIterator(self.graph): + for node in onnxscript.ir.traversal.RecursiveGraphIterator(self.graph): if node.graph is not None and node.graph not in seen_graphs: seen_graphs.add(node.graph) yield node.graph diff --git a/onnxscript/ir/_core_test.py b/onnxscript/ir/_core_test.py index 9b6cc94f6f..b20a17681c 100644 --- a/onnxscript/ir/_core_test.py +++ b/onnxscript/ir/_core_test.py @@ -1152,6 +1152,61 @@ def test_topological_sort_subgraph(self): ) +class ModelTest(unittest.TestCase): + def test_graphs_returns_all_subgraphs(self): + # main_graph: nodes=[a,b,c,d,>,if], edges=[(a,>),(b,>),(>,if)], subgraphs={if:[then_graph,else_graph]} + # then_graph: nodes=[sub], edges=[(c,sub),(d,sub)] + # else_graph: nodes=[add], edges=[(c,add),(d,add)] + v0 = _core.Value(name="va") + v1 = _core.Value(name="vb") + v2 = _core.Value(name="vc") + v3 = _core.Value(name="vd") + node0 = _core.Node("", "a", inputs=(v0,), num_outputs=1) + node1 = _core.Node("", "b", inputs=(v1,), num_outputs=1) + node2 = _core.Node("", "c", inputs=(v2,), num_outputs=1) + node3 = _core.Node("", "d", inputs=(v3,), num_outputs=1) + node4 = _core.Node( + "", "sub", inputs=(node2.outputs[0], node3.outputs[0]), num_outputs=1 + ) + node5 = _core.Node( + "", "add", inputs=(node2.outputs[0], node3.outputs[0]), num_outputs=1 + ) + node6 = _core.Node("", ">", inputs=(node0.outputs[0], node1.outputs[0]), num_outputs=1) + then_graph = _core.Graph( + inputs=(node2.outputs[0], node3.outputs[0]), + outputs=(node4.outputs[0],), + nodes=(node4,), + name="then_graph", + ) + else_graph = _core.Graph( + inputs=(node2.outputs[0], node3.outputs[0]), + outputs=(node5.outputs[0],), + nodes=(node5,), + name="else_graph", + ) + node7 = _core.Node( + "", + "if", + inputs=(node6.outputs[0],), + num_outputs=1, + attributes=[ + ir.AttrGraph("then_branch", then_graph), + ir.AttrGraph("else_branch", else_graph), + ], + ) + main_graph = _core.Graph( + inputs=(v0, v1, v2, v3), + outputs=(node7.outputs[0],), + nodes=(node0, node1, node2, node6, node7), + name="main_graph", + ) + model = _core.Model(main_graph, ir_version=10) + self.assertEqual( + tuple(model.graphs()), + (main_graph, then_graph, else_graph), + ) + + class TypeTest(unittest.TestCase): @parameterized.parameterized.expand( [