Skip to content

Commit 77fba51

Browse files
authored
[pass] Update LiftSubgraphInitializersToMainGraphPass to disallow variable shadowing (#2348)
Variable shadowing (reusing value names) is disallowed in ONNX across the main graph and subgraphs according to the spec (https://github.com/onnx/onnx/pull/6955/files). This change updates to the logic to check and raise on such cases. A subsequent PR will implement #1432 to allow users to fix names explicitly.
1 parent b5b51c0 commit 77fba51

File tree

2 files changed

+47
-16
lines changed

2 files changed

+47
-16
lines changed

onnxscript/ir/passes/common/constant_manipulation.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -137,11 +137,37 @@ class LiftSubgraphInitializersToMainGraphPass(ir.passes.InPlacePass):
137137
This pass lifts the initializers of a subgraph to the main graph.
138138
It is used to ensure that the initializers are available in the main graph
139139
for further processing or optimization.
140+
141+
Initializers that are also graph inputs will not be lifted.
142+
143+
Preconditions:
144+
- All initializers in the model must have unique names across the main graph and subgraphs.
140145
"""
141146

147+
def requires(self, model: ir.Model) -> None:
148+
"""Ensure all initializer names are unique."""
149+
registered_initializer_names: set[str] = set()
150+
duplicated_initializers: list[ir.Value] = []
151+
for graph in model.graphs():
152+
for initializer in graph.initializers.values():
153+
if initializer.name is None:
154+
raise ir.passes.PreconditionError(
155+
f"Initializer name is None. Please ensure all initializers have unique names: {initializer!r}"
156+
)
157+
if initializer.name in registered_initializer_names:
158+
duplicated_initializers.append(initializer)
159+
else:
160+
registered_initializer_names.add(initializer.name)
161+
if duplicated_initializers:
162+
raise ir.passes.PreconditionError(
163+
"Found duplicated initializers in the model. "
164+
"Initializer name must be unique across the main graph and subgraphs. "
165+
"Please ensure all initializers have unique names. Duplicated: "
166+
f"{duplicated_initializers!r}"
167+
)
168+
142169
def call(self, model: ir.Model) -> ir.passes.PassResult:
143170
count = 0
144-
registered_initializer_names: dict[str, int] = {}
145171
for graph in model.graphs():
146172
if graph is model.graph:
147173
continue
@@ -156,15 +182,6 @@ def call(self, model: ir.Model) -> ir.passes.PassResult:
156182
continue
157183
# Remove the initializer from the subgraph
158184
graph.initializers.pop(name)
159-
# To avoid name conflicts, we need to rename the initializer
160-
# to a unique name in the main graph
161-
if name in registered_initializer_names:
162-
name_count = registered_initializer_names[name]
163-
initializer.name = f"{name}_{name_count}"
164-
registered_initializer_names[name] = name_count + 1
165-
else:
166-
assert initializer.name is not None
167-
registered_initializer_names[initializer.name] = 1
168185
model.graph.register_initializer(initializer)
169186
count += 1
170187
logger.debug(

onnxscript/ir/passes/common/constant_manipulation_test.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -248,12 +248,12 @@ def test_not_lifting_constants_to_initializers_when_it_is_output(self):
248248
class TestLiftSubgraphInitializersToMainGraphPass(unittest.TestCase):
249249
@parameterized.parameterized.expand(
250250
[
251-
("then_initializer", "else_initializer"),
252-
("initializer", "initializer"),
251+
("unique_init_names", "then_initializer", "else_initializer"),
252+
("duplicated_init_names", "initializer", "initializer"),
253253
]
254254
)
255255
def test_pass_with_lifting_constants_to_initializers_within_subgraph(
256-
self, then_initializer_name, else_initializer_name
256+
self, _: str, then_initializer_name: str, else_initializer_name: str
257257
):
258258
input_value = ir.Value(
259259
name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3))
@@ -311,6 +311,13 @@ def test_pass_with_lifting_constants_to_initializers_within_subgraph(
311311
graph=main_graph,
312312
ir_version=10,
313313
)
314+
if then_initializer_name == else_initializer_name:
315+
with self.assertRaisesRegex(
316+
ir.passes.PreconditionError,
317+
"Initializer name must be unique across the main graph and subgraphs",
318+
):
319+
constant_manipulation.LiftSubgraphInitializersToMainGraphPass()(model)
320+
return
314321
result = constant_manipulation.LiftSubgraphInitializersToMainGraphPass()(model)
315322
self.assertTrue(result.modified)
316323

@@ -325,12 +332,12 @@ def test_pass_with_lifting_constants_to_initializers_within_subgraph(
325332

326333
@parameterized.parameterized.expand(
327334
[
328-
("then_initializer", "else_initializer"),
329-
("initializer", "initializer"),
335+
("unique_init_names", "then_initializer", "else_initializer"),
336+
("duplicated_init_names", "initializer", "initializer"),
330337
]
331338
)
332339
def test_pass_does_not_lift_initialized_inputs_in_subgraph(
333-
self, then_initializer_name, else_initializer_name
340+
self, _: str, then_initializer_name: str, else_initializer_name: str
334341
):
335342
input_value = ir.Value(
336343
name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3))
@@ -390,6 +397,13 @@ def test_pass_does_not_lift_initialized_inputs_in_subgraph(
390397
graph=main_graph,
391398
ir_version=10,
392399
)
400+
if then_initializer_name == else_initializer_name:
401+
with self.assertRaisesRegex(
402+
ir.passes.PreconditionError,
403+
"Initializer name must be unique across the main graph and subgraphs",
404+
):
405+
constant_manipulation.LiftSubgraphInitializersToMainGraphPass()(model)
406+
return
393407
result = constant_manipulation.LiftSubgraphInitializersToMainGraphPass()(model)
394408
self.assertTrue(result.modified)
395409

0 commit comments

Comments
 (0)