Skip to content

Commit 3d8f64a

Browse files
authored
Turn constant folder and dce into passes (#2109)
Turn constant folder and dce into passes to allow them to be used as individual passes in the future.
1 parent a63c282 commit 3d8f64a

File tree

2 files changed

+42
-32
lines changed

2 files changed

+42
-32
lines changed

onnxscript/optimizer/_constant_folding.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -797,9 +797,7 @@ def merge_dims(dim1, dim2):
797797
return ir.Shape([merge_dims(dim1, dim2) for dim1, dim2 in zip(shape1, shape2)])
798798

799799

800-
class ConstantFolder:
801-
opset_imports: dict[str, int]
802-
800+
class FoldConstantsPass(ir.passes.PassBase):
803801
def __init__(
804802
self,
805803
*,
@@ -812,11 +810,17 @@ def __init__(
812810
self._shape_inference = shape_inference
813811
self._input_size_limit = input_size_limit
814812
self._output_size_limit = output_size_limit
815-
self._init()
816-
817-
def _init(self) -> None:
813+
self.opset_imports: dict[str, int] = {}
818814
self.counts: dict[str, int] = {}
819815
self.sizes: dict[str, int] = {}
816+
self.modified: bool = False
817+
self._state = OptimizerState()
818+
self._reset()
819+
820+
def _reset(self) -> None:
821+
"""Reset internal states for a new run."""
822+
self.counts = {}
823+
self.sizes = {}
820824
self.modified = False
821825
self._state = OptimizerState()
822826

@@ -931,6 +935,7 @@ def process_node(self, node: ir.Node):
931935
sym_value.name,
932936
)
933937
node.replace_input_with(i, sym_value)
938+
self.modified = True
934939
# TODO(rama): consider merging type/other info from both values
935940

936941
# Do incremental shape inference
@@ -1007,6 +1012,8 @@ def replace_node(self, node: ir.Node, replacement, root: ir.Graph | ir.Function)
10071012
root, node, [node], replacement.new_nodes, node.outputs, replacement.new_outputs
10081013
)
10091014

1015+
self.modified = True
1016+
10101017
# TODO: what about new opset_imports?
10111018
# TODO: track statistics about replaced nodes and sizes of new constants
10121019

@@ -1045,13 +1052,14 @@ def visit_function(self, function: ir.Function) -> None:
10451052
for node in function:
10461053
self.visit_node(node, function)
10471054

1048-
def visit_model(self, model: ir.Model) -> None:
1049-
self._init()
1055+
def call(self, model: ir.Model) -> ir.passes.PassResult:
1056+
self._reset()
10501057
self.opset_imports = model.opset_imports
10511058
self.visit_graph(model.graph)
10521059
for function in model.functions.values():
10531060
# TODO(rama): Should we specialize functions?
10541061
self.visit_function(function)
1062+
return ir.passes.PassResult(model, self.modified)
10551063

10561064

10571065
def fold_constants(
@@ -1066,18 +1074,18 @@ def fold_constants(
10661074
Applies constant folding optimization to the model.
10671075
Returns true iff the model was modified.
10681076
"""
1069-
folder = ConstantFolder(
1077+
folder_pass = FoldConstantsPass(
10701078
external_data_folder=external_data_folder,
10711079
shape_inference=onnx_shape_inference,
10721080
input_size_limit=input_size_limit,
10731081
output_size_limit=output_size_limit,
10741082
)
1075-
folder.visit_model(model)
1076-
for op in folder.counts:
1083+
folder_pass(model)
1084+
for op in folder_pass.counts:
10771085
logger.info(
10781086
"Constant-folded '%s' %s times, with %s size.",
10791087
op,
1080-
folder.counts[op],
1081-
folder.sizes[op],
1088+
folder_pass.counts[op],
1089+
folder_pass.sizes[op],
10821090
)
1083-
return folder.modified
1091+
return folder_pass.modified

onnxscript/optimizer/_remove_unused.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def is_used_output(i: int) -> bool:
5555
out.name = ""
5656

5757

58-
def process_function_or_graph(function_or_graph: ir.Function | ir.Graph) -> int:
58+
def _process_function_or_graph(function_or_graph: ir.Function | ir.Graph) -> int:
5959
graph_outputs = frozenset(function_or_graph.outputs)
6060
onnx_opset_version = function_or_graph.opset_imports.get("", None)
6161
count = 0
@@ -75,32 +75,34 @@ def process_function_or_graph(function_or_graph: ir.Function | ir.Graph) -> int:
7575
if not isinstance(attr, ir.Attr):
7676
continue
7777
if attr.type == ir.AttributeType.GRAPH:
78-
count += process_function_or_graph(attr.as_graph())
78+
count += _process_function_or_graph(attr.as_graph())
7979
elif attr.type == ir.AttributeType.GRAPHS:
8080
for graph in attr.as_graphs():
81-
count += process_function_or_graph(graph)
81+
count += _process_function_or_graph(graph)
8282
return count
8383

8484

85-
def _remove_unused_nodes(model: ir.Model) -> None:
86-
"""Removes unused nodes from a model in IR form."""
87-
count = process_function_or_graph(model.graph)
88-
graph_outputs = frozenset(model.graph.outputs)
89-
initializers = model.graph.initializers
90-
for init in list(initializers.values()):
91-
if not (init in graph_outputs or init.uses()):
92-
del initializers[init.name] # type: ignore[arg-type]
93-
count += 1
94-
95-
for function in model.functions.values():
96-
count += process_function_or_graph(function)
97-
98-
logger.info("Removed %s unused nodes", count)
85+
class RemoveUnusedNodesPass(ir.passes.PassBase):
86+
def call(self, model: ir.Model) -> ir.passes.PassResult:
87+
count = _process_function_or_graph(model.graph)
88+
graph_outputs = frozenset(model.graph.outputs)
89+
initializers = model.graph.initializers
90+
for init in list(initializers.values()):
91+
if not (init in graph_outputs or init.uses()):
92+
assert init.name is not None
93+
del initializers[init.name]
94+
count += 1
95+
for function in model.functions.values():
96+
count += _process_function_or_graph(function)
97+
if count:
98+
logger.info("Removed %s unused nodes", count)
99+
return ir.passes.PassResult(model, modified=True)
100+
return ir.passes.PassResult(model, modified=False)
99101

100102

101103
def remove_unused_nodes(model: ir.Model | onnx.ModelProto) -> None:
102104
"""Removes unused nodes from a model."""
103105
if isinstance(model, ir.Model):
104-
_remove_unused_nodes(model)
106+
RemoveUnusedNodesPass()(model)
105107
else:
106108
onnxscript.optimizer._legacy._remove_unused_proto.remove_unused_nodes(model)

0 commit comments

Comments
 (0)