Skip to content

Commit 73c1d2b

Browse files
authored
Implement shape merging in identity elimination pass
Following microsoft/onnxscript#2588. Handle shape info as well. Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent 49a361c commit 73c1d2b

1 file changed

Lines changed: 26 additions & 0 deletions

File tree

src/onnx_ir/passes/common/identity_elimination.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,27 @@
1515
logger = logging.getLogger(__name__)
1616

1717

18+
def _merge_shapes(shape1: ir.Shape | None, shape2: ir.Shape | None) -> ir.Shape | None:
19+
def merge_dims(dim1, dim2):
20+
if dim1 == dim2:
21+
return dim1
22+
if not isinstance(dim1, ir.SymbolicDim):
23+
return dim1 # Prefer int value over symbolic dim
24+
if not isinstance(dim2, ir.SymbolicDim):
25+
return dim2
26+
if dim1.value is None:
27+
return dim2
28+
return dim1
29+
30+
if shape1 is None:
31+
return shape2
32+
if shape2 is None:
33+
return shape1
34+
if len(shape1) != len(shape2):
35+
raise ValueError("Shapes must have the same rank.")
36+
return ir.Shape([merge_dims(dim1, dim2) for dim1, dim2 in zip(shape1, shape2)])
37+
38+
1839
class IdentityEliminationPass(ir.passes.InPlacePass):
1940
"""Pass for eliminating redundant Identity nodes.
2041
@@ -75,6 +96,11 @@ def _try_eliminate_identity_node(self, node: ir.Node) -> bool:
7596
if output_is_graph_output and input_is_graph_input:
7697
return False
7798

99+
# Copy over shape/type if the output has more complete information
100+
input_value.shape = _merge_shapes(input_value.shape, output_value.shape)
101+
if input_value.type is None:
102+
input_value.type = output_value.type
103+
78104
# Case 1 & 2 (merged): Eliminate the identity node
79105
# Replace all uses of output with input
80106
ir.convenience.replace_all_uses_with(output_value, input_value)

0 commit comments

Comments
 (0)