|
15 | 15 | logger = logging.getLogger(__name__) |
16 | 16 |
|
17 | 17 |
|
| 18 | +def _merge_shapes(shape1: ir.Shape | None, shape2: ir.Shape | None) -> ir.Shape | None: |
| 19 | + def merge_dims(dim1, dim2): |
| 20 | + if dim1 == dim2: |
| 21 | + return dim1 |
| 22 | + if not isinstance(dim1, ir.SymbolicDim): |
| 23 | + return dim1 # Prefer int value over symbolic dim |
| 24 | + if not isinstance(dim2, ir.SymbolicDim): |
| 25 | + return dim2 |
| 26 | + if dim1.value is None: |
| 27 | + return dim2 |
| 28 | + return dim1 |
| 29 | + |
| 30 | + if shape1 is None: |
| 31 | + return shape2 |
| 32 | + if shape2 is None: |
| 33 | + return shape1 |
| 34 | + if len(shape1) != len(shape2): |
| 35 | + raise ValueError("Shapes must have the same rank.") |
| 36 | + return ir.Shape([merge_dims(dim1, dim2) for dim1, dim2 in zip(shape1, shape2)]) |
| 37 | + |
| 38 | + |
18 | 39 | class IdentityEliminationPass(ir.passes.InPlacePass): |
19 | 40 | """Pass for eliminating redundant Identity nodes. |
20 | 41 |
|
@@ -75,6 +96,11 @@ def _try_eliminate_identity_node(self, node: ir.Node) -> bool: |
75 | 96 | if output_is_graph_output and input_is_graph_input: |
76 | 97 | return False |
77 | 98 |
|
| 99 | + # Copy over shape/type if the output has more complete information |
| 100 | + input_value.shape = _merge_shapes(input_value.shape, output_value.shape) |
| 101 | + if input_value.type is None: |
| 102 | + input_value.type = output_value.type |
| 103 | + |
78 | 104 | # Case 1 & 2 (merged): Eliminate the identity node |
79 | 105 | # Replace all uses of output with input |
80 | 106 | ir.convenience.replace_all_uses_with(output_value, input_value) |
|
0 commit comments