Skip to content

Commit 191ddb4

Browse files
Add deduplication pass for initializer tensors (#66)
1 parent c4e8371 commit 191ddb4

File tree

1 file changed

+49
-0
lines changed

1 file changed

+49
-0
lines changed
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from onnx_ir import ir
2+
from onnx_ir.passes.base import GraphTransformPass
3+
4+
5+
class DeduplicateInitializersPass(GraphTransformPass):
    """Remove duplicate initializer tensors from a graph.

    Two initializers are treated as duplicates when their content-based
    fingerprints are equal. The fingerprint consists of:
    - Tensor byte content (`tobytes()`)
    - Data type (`dtype`)
    - Shape

    All duplicates are replaced with the first (canonical) occurrence, and node
    inputs referring to redundant initializers are updated accordingly.
    Initializers that carry no constant value cannot be fingerprinted and are
    left untouched.
    """

    def apply(self, graph: ir.Graph) -> ir.Graph:
        """Deduplicate the initializers of ``graph`` in place.

        Args:
            graph: The graph whose initializers are deduplicated. It is
                modified in place.

        Returns:
            The same ``graph`` object that was passed in.
        """
        seen = {}  # Maps (tobytes, dtype, shape) -> canonical initializer name
        name_map = {}  # Maps duplicate initializer name -> canonical name

        # Snapshot the values first: duplicates are popped from
        # `graph.initializers` while we walk over them.
        for initializer in list(graph.initializers.values()):
            tensor = initializer.const_value
            if tensor is None:
                # No data to fingerprint; dereferencing it would raise
                # AttributeError, so keep the initializer as-is.
                continue

            key = (
                tensor.tobytes(),  # Content fingerprint
                tensor.dtype,  # Data type
                tuple(tensor.shape),  # Shape tuple
            )

            if key in seen:
                # Found a duplicate: store the name mapping and remove it from graph
                name_map[initializer.name] = seen[key]
                graph.initializers.pop(initializer.name)
            else:
                # First time seeing this tensor -> keep it as the canonical one
                seen[key] = initializer.name

        if not name_map:
            # Nothing was removed, so no node input needs rewiring.
            return graph

        # Update node inputs to use the canonical initializer names.
        # NOTE(review): only inputs of the nodes yielded by iterating `graph`
        # are rewired. Graph outputs, and presumably nodes inside nested
        # subgraphs, that reference a removed initializer are not updated
        # here -- confirm whether that needs handling.
        for node in graph:
            for i, input_value in enumerate(node.inputs):
                if input_value is None:
                    continue
                canonical_name = name_map.get(input_value.name)
                if canonical_name is not None:
                    # Replace input with the deduplicated initializer
                    node.replace_input_with(i, graph.initializers[canonical_name])

        return graph
49+

0 commit comments

Comments
 (0)