From f976ed1faa96d91d3aa102bfc5697a60ea2bd731 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 14 Apr 2025 10:19:34 -0700 Subject: [PATCH] Add test for TopologicalSortPass on functions Add a test for `TopologicalSortPass` on functions in a model in `onnxscript/ir/passes/common/topological_sort_test.py`. * Add `test_topological_sort_on_functions` function to verify `TopologicalSortPass` on functions. * Create a function with unsorted nodes and a model containing the function. * Apply `TopologicalSortPass` and verify that the nodes in the function are sorted correctly. --- For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/microsoft/onnxscript?shareId=XXXX-XXXX-XXXX-XXXX). --- .../ir/passes/common/topological_sort_test.py | 35 +++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/onnxscript/ir/passes/common/topological_sort_test.py b/onnxscript/ir/passes/common/topological_sort_test.py index ca9d1377f0..8680761f1e 100644 --- a/onnxscript/ir/passes/common/topological_sort_test.py +++ b/onnxscript/ir/passes/common/topological_sort_test.py @@ -45,6 +45,41 @@ def test_topological_sort_modified_false(self): (self.node_a, self.node_b, self.node_c), ) + def test_topological_sort_on_functions(self): + """Test that TopologicalSortPass works on functions in a model.""" + # Create a function with unsorted nodes + func_graph = ir.Graph( + inputs=self.node_a.inputs, + outputs=self.node_c.outputs, + nodes=[self.node_c, self.node_b, self.node_a], # Unsorted nodes + ) + function = ir.Function( + domain="test_domain", + name="test_function", + graph=func_graph, + attributes=[], + ) + + # Create a model with the function + graph = ir.Graph( + inputs=[], + outputs=[], + nodes=[], + name="test_graph", + ) + model = ir.Model(graph, ir_version=10, functions=[function]) + + # Apply the TopologicalSortPass + result = topological_sort.TopologicalSortPass()(model) + + # Verify that the nodes in the function are sorted + sorted_func_nodes = (self.node_a, self.node_b, self.node_c) + self.assertTrue(result.modified) + self.assertEqual( + tuple(result.model.functions[function.identifier()]), + sorted_func_nodes, + ) + if __name__ == "__main__": unittest.main()