@@ -797,9 +797,7 @@ def merge_dims(dim1, dim2):
     return ir.Shape([merge_dims(dim1, dim2) for dim1, dim2 in zip(shape1, shape2)])


-class ConstantFolder:
-    opset_imports: dict[str, int]
-
+class FoldConstantsPass(ir.passes.PassBase):
     def __init__(
         self,
         *,
@@ -812,11 +810,17 @@ def __init__(
         self._shape_inference = shape_inference
         self._input_size_limit = input_size_limit
         self._output_size_limit = output_size_limit
-        self._init()
-
-    def _init(self) -> None:
+        self.opset_imports: dict[str, int] = {}
         self.counts: dict[str, int] = {}
         self.sizes: dict[str, int] = {}
+        self.modified: bool = False
+        self._state = OptimizerState()
+        self._reset()
+
+    def _reset(self) -> None:
+        """Reset internal states for a new run."""
+        self.counts = {}
+        self.sizes = {}
         self.modified = False
         self._state = OptimizerState()

@@ -931,6 +935,7 @@ def process_node(self, node: ir.Node):
                     sym_value.name,
                 )
                 node.replace_input_with(i, sym_value)
+                self.modified = True
                 # TODO(rama): consider merging type/other info from both values

         # Do incremental shape inference
@@ -1007,6 +1012,8 @@ def replace_node(self, node: ir.Node, replacement, root: ir.Graph | ir.Function)
             root, node, [node], replacement.new_nodes, node.outputs, replacement.new_outputs
         )

+        self.modified = True
+
         # TODO: what about new opset_imports?
         # TODO: track statistics about replaced nodes and sizes of new constants

@@ -1045,13 +1052,14 @@ def visit_function(self, function: ir.Function) -> None:
         for node in function:
             self.visit_node(node, function)

-    def visit_model(self, model: ir.Model) -> None:
-        self._init()
+    def call(self, model: ir.Model) -> ir.passes.PassResult:
+        self._reset()
         self.opset_imports = model.opset_imports
         self.visit_graph(model.graph)
         for function in model.functions.values():
             # TODO(rama): Should we specialize functions?
             self.visit_function(function)
+        return ir.passes.PassResult(model, self.modified)


 def fold_constants(
@@ -1066,18 +1074,18 @@ def fold_constants(
     Applies constant folding optimization to the model.
     Returns true iff the model was modified.
     """
-    folder = ConstantFolder(
+    folder_pass = FoldConstantsPass(
         external_data_folder=external_data_folder,
         shape_inference=onnx_shape_inference,
         input_size_limit=input_size_limit,
         output_size_limit=output_size_limit,
     )
-    folder.visit_model(model)
-    for op in folder.counts:
+    folder_pass(model)
+    for op in folder_pass.counts:
         logger.info(
             "Constant-folded '%s' %s times, with %s size.",
             op,
-            folder.counts[op],
-            folder.sizes[op],
+            folder_pass.counts[op],
+            folder_pass.sizes[op],
         )
-    return folder.modified
+    return folder_pass.modified
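
For reference, a minimal usage sketch of the pass-style API this diff introduces. It assumes `ir` is `onnxscript.ir` and that `ir.load`/`ir.save` are available for model I/O; the file paths and size limits are illustrative, and `FoldConstantsPass` is the class defined in the diffed module. `PassBase.__call__` is expected to run `call()` and return the `PassResult` constructed above.

from onnxscript import ir

# Sketch only (assumed surrounding API): load a model, run the pass, save if modified.
model = ir.load("model.onnx")  # illustrative input path
folder_pass = FoldConstantsPass(
    external_data_folder="",   # illustrative values for the keyword-only arguments
    shape_inference=True,
    input_size_limit=1024,
    output_size_limit=1024,
)
result = folder_pass(model)    # calling the pass dispatches to call() and returns a PassResult
if result.modified:
    ir.save(result.model, "model_folded.onnx")  # illustrative output path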