Skip to content

Commit 312219b

Browse files
[pass] Avoid lifting tensors that are too small to initializers (#2190)
Tensors with too few elements are usually not weights, and there tend to be many of them. Lifting them makes the initializer list very noisy. I added a `size_limit` parameter to control this, with a default of 16. --------- Co-authored-by: Shubham Bhokare <[email protected]>
1 parent 8f96dc9 commit 312219b

File tree

2 files changed

+28
-10
lines changed

2 files changed

+28
-10
lines changed

onnxscript/ir/passes/common/constant_manipulation.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,14 @@ class LiftConstantsToInitializersPass(ir.passes.InPlacePass):
2323
Attributes:
2424
lift_all_constants: Whether to lift all Constant nodes, including those that do not contain a tensor attribute (e.g. with value_ints etc.)
2525
Defaults to False, in which case only Constants with the ``value`` attribute are lifted.
26+
size_limit: The minimum number of elements a tensor must have to be lifted. If the tensor
27+
contains fewer than size_limit elements, it will not be lifted. Default is 16.
2628
"""
2729

28-
def __init__(self, lift_all_constants: bool = False):
30+
def __init__(self, lift_all_constants: bool = False, size_limit: int = 16):
2931
super().__init__()
30-
self._lift_all_constants = lift_all_constants
32+
self.lift_all_constants = lift_all_constants
33+
self.size_limit = size_limit
3134

3235
def call(self, model: ir.Model) -> ir.passes.PassResult:
3336
count = 0
@@ -79,16 +82,17 @@ def call(self, model: ir.Model) -> ir.passes.PassResult:
7982

8083
def _constant_node_attribute_to_tensor(
8184
self, node, attr_name: str, attr_value: ir.Attr, initializer_name: str
82-
) -> ir.Tensor | None:
85+
) -> ir.TensorProtocol | None:
8386
"""Convert constant node attribute to tensor."""
84-
if not self._lift_all_constants and attr_name != "value":
87+
if not self.lift_all_constants and attr_name != "value":
8588
logger.debug(
8689
"Constant node '%s' has non-tensor attribute '%s'", node.name, attr_name
8790
)
8891
return None
8992

93+
tensor: ir.TensorProtocol
9094
if attr_name == "value":
91-
tensor = attr_value.as_tensor() # type: ignore[union-attr]
95+
tensor = attr_value.as_tensor()
9296
elif attr_name == "value_int":
9397
tensor = ir.tensor(
9498
attr_value.as_int(), dtype=ir.DataType.INT64, name=initializer_name
@@ -110,5 +114,15 @@ def _constant_node_attribute_to_tensor(
110114
np.array(attr_value.value, dtype=np.bytes_), name=initializer_name
111115
)
112116
else:
113-
tensor = None
114-
return tensor # type: ignore[return-value]
117+
raise ValueError(
118+
f"Unsupported constant node '{node.name}' attribute '{attr_name}'"
119+
)
120+
121+
if tensor.size < self.size_limit:
122+
logger.debug(
123+
"Tensor from node '%s' has less than %s elements",
124+
node.name,
125+
self.size_limit,
126+
)
127+
return None
128+
return tensor

onnxscript/ir/passes/common/constant_manipulation_test.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def test_pass_with_lifting_float_and_int_constants_to_initializers(
5656

5757
# Perform lift constants to initializers
5858
result = constant_manipulation.LiftConstantsToInitializersPass(
59-
lift_all_constants=lift_all_constants
59+
lift_all_constants=lift_all_constants, size_limit=0
6060
)(model)
6161
self.assertTrue(result.modified)
6262
# Check that the constant node is lifted to an initializer
@@ -130,7 +130,7 @@ def test_pass_with_lifting_constants_to_initializers_within_subgraph(
130130
ir_version=10,
131131
)
132132
result = constant_manipulation.LiftConstantsToInitializersPass(
133-
lift_all_constants=lift_all_constants
133+
lift_all_constants=lift_all_constants, size_limit=0
134134
)(model)
135135
self.assertTrue(result.modified)
136136
# Check that the constant node is lifted to the subgraph initializers
@@ -206,7 +206,7 @@ def test_pass_with_lifting_constants_to_initializers_with_floats_ints_strings(
206206

207207
# Perform lift constants to initializers
208208
result = constant_manipulation.LiftConstantsToInitializersPass(
209-
lift_all_constants=lift_all_constants
209+
lift_all_constants=lift_all_constants, size_limit=0
210210
)(model)
211211
if lift_all_constants:
212212
self.assertTrue(result.modified)
@@ -249,3 +249,7 @@ def test_not_lifting_constants_to_initializers_when_it_is_output(self):
249249
self.assertFalse(result.modified)
250250
# Check that the constant node is not lifted to an initializer
251251
self.assertEqual(len(result.model.graph.initializers), 0)
252+
253+
254+
if __name__ == "__main__":
255+
unittest.main()

0 commit comments

Comments
 (0)