Skip to content

Commit d8fa011

Browse files
Add deduplication pass for initializer tensors (#67)
### Summary This PR adds a new graph transformation pass: `DeduplicateInitializersPass`. It removes duplicate initializer tensors (typically model weights) based on a unique fingerprint derived from: - Data type (`dtype`) - Shape - Tensor byte content (`tobytes()`) Redundant initializers no larger than the configurable `size_limit` (default 1024) are removed, and nodes referencing them are updated to use the canonical (first-seen) tensor. --- ### Implementation Details - Fingerprints are tracked using a dictionary: `(dtype, shape, tobytes) → value` - Initializers without a constant value, or whose `size` exceeds `size_limit`, are skipped - Redundant initializers are removed using `graph.initializers.pop(...)` - Uses of a duplicate are redirected via `ir.convenience.replace_all_uses_with(...)` for correctness and safety --- ### Benefits - Reduces memory and file size by eliminating duplicated weight tensors - Simplifies graph structure for downstream optimization and export --- ### Files Added - `src/onnx_ir/passes/common/initializer_deduplication.py` - Unit tests for the pass ### Closes Closes #66 --------- Signed-off-by: Abhishek Herbert Samuel <[email protected]> Signed-off-by: Justin Chu <[email protected]> Co-authored-by: Justin Chu <[email protected]>
1 parent 908fa9c commit d8fa011

File tree

2 files changed

+154
-0
lines changed

2 files changed

+154
-0
lines changed
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Copyright (c) ONNX Project Contributors
2+
# SPDX-License-Identifier: Apache-2.0
3+
"""Pass for removing duplicated initializer tensors from a graph."""
4+
5+
from __future__ import annotations
6+
7+
# Names exported by ``from <this module> import *``.
__all__ = [
    "DeduplicateInitializersPass",
]
10+
11+
12+
import onnx_ir as ir
13+
14+
15+
class DeduplicateInitializersPass(ir.passes.InPlacePass):
    """Remove duplicated initializer tensors from the main graph.

    Two initializers are treated as duplicates when their dtype, shape and
    byte content (``tobytes()``) are all identical. Every use of a duplicate
    is redirected to the first initializer seen with that fingerprint, and
    the duplicate is removed from ``graph.initializers``.

    The following initializers are left untouched:

    - initializers that have no constant value;
    - initializers whose tensor ``size`` is greater than ``size_limit``;
    - string tensors;
    - initializers that are also graph inputs or graph outputs.

    Only the main graph is processed. To deduplicate initializers from
    subgraphs, use :class:`~onnx_ir.passes.common.LiftSubgraphInitializersToMainGraphPass`
    to lift the initializers to the main graph first before running this pass.

    Args:
        size_limit: Tensors whose ``size`` is greater than this value are
            skipped, which bounds the cost of materializing and hashing
            their bytes.
    """

    def __init__(self, size_limit: int = 1024):
        super().__init__()
        self.size_limit = size_limit

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Deduplicate initializers of ``model.graph`` in place.

        Args:
            model: The model whose main-graph initializers are deduplicated.

        Returns:
            A ``PassResult`` wrapping the same model object; ``modified`` is
            True when at least one initializer was removed.
        """
        graph = model.graph

        # Values on the graph boundary must keep their identity:
        # - an initializer that is also a graph input is a default value the
        #   caller may override, so merging it changes the model interface;
        # - only *uses* are redirected below, so popping an initializer that
        #   is a graph output would leave that output without a backing value.
        boundary_ids = {id(value) for value in (*graph.inputs, *graph.outputs)}

        initializers: dict[tuple[ir.DataType, tuple[int, ...], bytes], ir.Value] = {}
        modified = False

        # Iterate over a snapshot: duplicates are popped from
        # graph.initializers while looping.
        for initializer in tuple(graph.initializers.values()):
            # TODO(justinchuby): Handle subgraphs as well. For now users can lift
            # initializers from subgraphs to the main graph before running this pass.
            if id(initializer) in boundary_ids:
                continue

            const_val = initializer.const_value
            if const_val is None:
                # Skip if initializer has no constant value
                continue

            if const_val.dtype == ir.DataType.STRING:
                # NOTE(review): string tensors presumably have no fixed-width byte
                # representation, so tobytes() is not a trustworthy fingerprint
                # for them - confirm against the tensor implementations.
                continue

            if const_val.size > self.size_limit:
                # Checked before tobytes() so large tensors are never materialized.
                continue

            key = (const_val.dtype, tuple(const_val.shape), const_val.tobytes())
            if key in initializers:
                modified = True
                ir.convenience.replace_all_uses_with(initializer, initializers[key])  # type: ignore[index]
                # Type narrowing only: the value was obtained from graph.initializers.
                assert initializer.name is not None
                graph.initializers.pop(initializer.name)
            else:
                # First value seen with this fingerprint becomes the canonical one.
                initializers[key] = initializer  # type: ignore[index]

        return ir.passes.PassResult(model=model, modified=modified)
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# Copyright (c) ONNX Project Contributors
2+
# SPDX-License-Identifier: Apache-2.0
3+
"""Unit tests for the initializer_deduplication passes."""
4+
5+
import unittest
6+
7+
import onnx_ir as ir
8+
from onnx_ir.passes.common import initializer_deduplication
9+
10+
11+
class DeduplicateInitializersTest(unittest.TestCase):
    """Exercise ``DeduplicateInitializersPass`` on small models written in ONNX text form."""

    def apply_pass(self, model: ir.Model) -> ir.Model:
        """Run the pass with its default settings and return the resulting model."""
        dedup_pass = initializer_deduplication.DeduplicateInitializersPass()
        outcome = dedup_pass(model)
        return outcome.model

    def test_deduplicates_identical_initializers(self):
        """Two identical tensors collapse into one, and both Add inputs share it."""
        source = ir.from_onnx_text(
            """
            <ir_version: 10, opset_import: ["" : 17]>
            agraph () => ()
            <float[3] w1 = {1.0, 2.0, 3.0}, float[3] w2 = {1.0, 2.0, 3.0}> {
                sum = Add(w1, w2)
            }
            """
        )
        self.assertEqual(len(source.graph.initializers), 2)
        deduped = self.apply_pass(source)
        self.assertEqual(len(deduped.graph.initializers), 1)
        node = deduped.graph[0]
        lhs = node.inputs[0]
        rhs = node.inputs[1]
        self.assertEqual(lhs, rhs)

    def test_initializers_with_different_shapes_not_deduplicated(self):
        """Same values with different shapes are kept apart."""
        source = ir.from_onnx_text(
            """
            <ir_version: 10, opset_import: ["" : 17]>
            agraph () => ()
            <float[2] w1 = {1.0, 2.0}, float[1,2] w2 = {1.0, 2.0}> {
                sum = Add(w1, w2)
            }
            """
        )
        deduped = self.apply_pass(source)
        self.assertEqual(len(deduped.graph.initializers), 2)

    def test_initializers_with_different_dtypes_not_deduplicated(self):
        """Same values with different dtypes are kept apart."""
        source = ir.from_onnx_text(
            """
            <ir_version: 10, opset_import: ["" : 17]>
            agraph () => ()
            <float[2] w1 = {1.0, 2.0}, double[2] w2 = {1.0, 2.0}> {
                sum = Add(w1, w2)
            }
            """
        )
        deduped = self.apply_pass(source)
        self.assertEqual(len(deduped.graph.initializers), 2)

    def test_scalar_initializer_deduplication(self):
        """Identical scalar initializers are merged."""
        source = ir.from_onnx_text(
            """
            <ir_version: 10, opset_import: ["" : 17]>
            agraph () => ()
            <float w1 = {5.0}, float w2 = {5.0}> {
                sum = Add(w1, w2)
            }
            """
        )
        deduped = self.apply_pass(source)
        self.assertEqual(len(deduped.graph.initializers), 1)

    def test_multiple_duplicates(self):
        """Three identical tensors collapse into a single initializer."""
        source = ir.from_onnx_text(
            """
            <ir_version: 10, opset_import: ["" : 17]>
            agraph () => ()
            <float[2] w1 = {1.0, 1.0}, float[2] w2 = {1.0, 1.0}, float[2] w3 = {1.0, 1.0}> {
                temp = Add(w1, w2)
                out = Add(temp, w3)
            }
            """
        )
        deduped = self.apply_pass(source)
        self.assertEqual(len(deduped.graph.initializers), 1)

    def test_unique_values_not_deduplicated(self):
        """Tensors with the same dtype and shape but different content are kept apart."""
        source = ir.from_onnx_text(
            """
            <ir_version: 10, opset_import: ["" : 17]>
            agraph () => ()
            <float[2] w1 = {1.0, 2.0}, float[2] w2 = {2.0, 1.0}> {
                sum = Add(w1, w2)
            }
            """
        )
        deduped = self.apply_pass(source)
        self.assertEqual(len(deduped.graph.initializers), 2)
97+
98+
99+
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)