|
| 1 | +from onnx_ir import ir |
| 2 | +from onnx_ir.passes.base import GraphTransformPass |
| 3 | + |
| 4 | + |
| 5 | +class DeduplicateInitializersPass(GraphTransformPass): |
| 6 | + """ |
| 7 | + This pass removes duplicate initializer tensors from the graph. |
| 8 | +
|
| 9 | + It identifies duplicates based on a content-based fingerprint consisting of: |
| 10 | + - Tensor byte content (`tobytes()`) |
| 11 | + - Data type (`dtype`) |
| 12 | + - Shape |
| 13 | +
|
| 14 | + All duplicates are replaced with the first (canonical) occurrence, and node |
| 15 | + inputs referring to redundant initializers are updated accordingly. |
| 16 | + """ |
| 17 | + |
| 18 | + def apply(self, graph: ir.Graph) -> ir.Graph: |
| 19 | + seen = {} # Maps (tobytes, dtype, shape) -> canonical initializer name |
| 20 | + name_map = {} # Maps duplicate initializer name -> canonical name |
| 21 | + |
| 22 | + # Iterate over all initializers in the graph |
| 23 | + for initializer in list(graph.initializers.values()): |
| 24 | + key = ( |
| 25 | + initializer.const_value.tobytes(), # Content fingerprint |
| 26 | + initializer.const_value.dtype, # Data type |
| 27 | + tuple(initializer.const_value.shape), # Shape tuple |
| 28 | + ) |
| 29 | + |
| 30 | + if key in seen: |
| 31 | + # Found a duplicate: store the name mapping and remove it from graph |
| 32 | + canonical_name = seen[key] |
| 33 | + name_map[initializer.name] = canonical_name |
| 34 | + graph.initializers.pop(initializer.name) |
| 35 | + else: |
| 36 | + # First time seeing this tensor → keep it |
| 37 | + seen[key] = initializer.name |
| 38 | + |
| 39 | + # Update node inputs to use the canonical initializer names |
| 40 | + for node in graph: |
| 41 | + for i, input_value in enumerate(node.inputs): |
| 42 | + if input_value is not None and input_value.name in name_map: |
| 43 | + # Replace input with the deduplicated initializer |
| 44 | + new_name = name_map[input_value.name] |
| 45 | + replacement = graph.initializers[new_name] |
| 46 | + node.replace_input_with(i, replacement) |
| 47 | + |
| 48 | + return graph |
| 49 | + |
0 commit comments