Skip to content

Commit 9d16b89

Browse files
authored
Add test for TopologicalSortPass on functions (#2198)
Add a test for `TopologicalSortPass` on functions in a model in `onnxscript/ir/passes/common/topological_sort_test.py`. * Add `test_topological_sort_on_functions` function to verify `TopologicalSortPass` on functions. * Create a function with unsorted nodes and a model containing the function. * Apply `TopologicalSortPass` and verify that the nodes in the function are sorted correctly. --- For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/microsoft/onnxscript/pull/2198?shareId=b8a28bde-b4e4-4037-9628-bf8c02bc144b).
1 parent e404922 commit 9d16b89

File tree

1 file changed

+35
-0
lines changed

1 file changed

+35
-0
lines changed

onnxscript/ir/passes/common/topological_sort_test.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,41 @@ def test_topological_sort_modified_false(self):
4545
(self.node_a, self.node_b, self.node_c),
4646
)
4747

48+
def test_topological_sort_on_functions(self):
49+
"""Test that TopologicalSortPass works on functions in a model."""
50+
# Create a function with unsorted nodes
51+
func_graph = ir.Graph(
52+
inputs=self.node_a.inputs,
53+
outputs=self.node_c.outputs,
54+
nodes=[self.node_c, self.node_b, self.node_a], # Unsorted nodes
55+
)
56+
function = ir.Function(
57+
domain="test_domain",
58+
name="test_function",
59+
graph=func_graph,
60+
attributes=[],
61+
)
62+
63+
# Create a model with the function
64+
graph = ir.Graph(
65+
inputs=[],
66+
outputs=[],
67+
nodes=[],
68+
name="test_graph",
69+
)
70+
model = ir.Model(graph, ir_version=10, functions=[function])
71+
72+
# Apply the TopologicalSortPass
73+
result = topological_sort.TopologicalSortPass()(model)
74+
75+
# Verify that the nodes in the function are sorted
76+
sorted_func_nodes = (self.node_a, self.node_b, self.node_c)
77+
self.assertTrue(result.modified)
78+
self.assertEqual(
79+
tuple(result.model.functions[function.identifier()]),
80+
sorted_func_nodes,
81+
)
82+
4883

4984
if __name__ == "__main__":
5085
unittest.main()

0 commit comments

Comments
 (0)