diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py
index 71f2665923..1c6a10a2c0 100644
--- a/onnxscript/optimizer/_constant_folding.py
+++ b/onnxscript/optimizer/_constant_folding.py
@@ -9,7 +9,7 @@
 import logging
 import math
 import typing
-from typing import Any, Callable, Iterable, Sequence, Union
+from typing import Any, Callable, Collection, Iterable, Sequence, Union
 
 import numpy as np
 import onnx
@@ -24,12 +24,7 @@
 DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT = 1024 * 1024
 
 
-def is_control_flow_op(node: ir.Node) -> bool:
-    graph_types = {ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS}
-    return any(attr.type in graph_types for attr in node.attributes.values())
-
-
-non_deterministic_ops = frozenset(
+_NON_DETERMINISTIC_OPS = frozenset(
     {
         "RandomUniform",
         "RandomNormal",
@@ -40,21 +35,21 @@ def is_control_flow_op(node: ir.Node) -> bool:
 )
 
 
-def is_non_deterministic_op(node: ir.Node) -> bool:
-    return node.op_type in non_deterministic_ops and utils.is_onnx_domain(node.domain)
+logger = logging.getLogger(__name__)
 
 
-def is_onnx_op(node: ir.Node, op_type: str) -> bool:
-    return node.op_type == op_type and utils.is_onnx_domain(node.domain)
+def _is_control_flow_op(node: ir.Node) -> bool:
+    graph_types = {ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS}
+    return any(attr.type in graph_types for attr in node.attributes.values())
 
 
-def is_constant_op(node: ir.Node) -> bool:
-    return node.op_type in {"Constant", "ConstantOfShape"} and utils.is_onnx_domain(
-        node.domain
-    )
+def _is_non_deterministic_op(node: ir.Node) -> bool:
+    return node.op_type in _NON_DETERMINISTIC_OPS and utils.is_onnx_domain(node.domain)
 
 
-logger = logging.getLogger(__name__)
+def _is_onnx_op(node: ir.Node, op_type: str) -> bool:
+    return node.op_type == op_type and utils.is_onnx_domain(node.domain)
+
 
 # "Standard" evaluators are used to perform constant-folding.
 # The API below works only for non-control-flow ops (ops without any graph-attributes).
@@ -168,19 +163,6 @@ def get_sym_value(self, value: ir.Value | None) -> SymbolicValue | None:
     def set_sym_value(self, value: ir.Value, sym_value: SymbolicValue) -> None:
         self._sym_value_map[value] = sym_value
 
-    def push_initializer_inputs(self) -> None:
-        self._initializer_inputs.append(set())
-
-    def pop_initializer_inputs(self) -> None:
-        self._initializer_inputs.pop()
-
-    def add_initializer_input(self, value: ir.Value) -> None:
-        assert self._initializer_inputs
-        self._initializer_inputs[-1].add(value)
-
-    def is_initializer_input(self, value: ir.Value) -> bool:
-        return any(value in inputs for inputs in self._initializer_inputs)
-
     def get_shape_value(self, value: ir.Value | None) -> ir.Shape | None:
         const_value = _get_numpy_value(value, ir.DataType.INT64, size_limit=10)
         if const_value is not None:
@@ -301,6 +283,11 @@ def _get_numpy_value(
         array = const_value.numpy().view(const_value.dtype.numpy())
     except FileNotFoundError:
         # External data is not available.
+        logger.warning(
+            "External data for value '%s' is not available. "
+            "This may lead to incorrect constant folding.",
+            val.name,
+        )
         return None
     assert isinstance(array, np.ndarray)
     return array
@@ -841,28 +828,48 @@ def merge_dims(dim1, dim2):
 
 
 class FoldConstantsPass(ir.passes.InPlacePass):
+    """A pass that folds constant expressions in the model.
+
+    Attributes:
+        shape_inference: Whether to perform shape inference.
+        input_size_limit: Maximum size of input tensors to fold.
+        output_size_limit: Maximum size of output tensors to fold.
+        always_fold_ops: Collection of op types that should always be folded.
+            For ops from the default opset, only the op_type is needed (e.g. "Transpose"),
+            otherwise specify the domain with ``{domain}::{op_type}``.
+    """
+
     def __init__(
         self,
         *,
         shape_inference: bool,
         input_size_limit: int,
         output_size_limit: int,
+        always_fold_ops: Collection[str] = frozenset(["Transpose"]),
     ) -> None:
-        self._shape_inference = shape_inference
-        self._input_size_limit = input_size_limit
-        self._output_size_limit = output_size_limit
-        self.opset_imports: dict[str, int] = {}
-        self.counts: dict[str, int] = {}
-        self.sizes: dict[str, int] = {}
-        self.modified: bool = False
+        self.shape_inference = shape_inference
+        self.input_size_limit = input_size_limit
+        self.output_size_limit = output_size_limit
+        ops = []
+        for name in always_fold_ops:
+            domain, op_type = name.split("::", 1) if "::" in name else ("", name)
+            if domain == "ai.onnx":
+                domain = ""
+            ops.append((domain, op_type))
+        self.always_fold_ops: frozenset[tuple[str, str]] = frozenset(ops)
+
+        self._opset_imports: dict[str, int] = {}
+        self._counts: dict[str, int] = {}
+        self._sizes: dict[str, int] = {}
+        self._modified: bool = False
         self._state = OptimizerState()
         self._reset()
 
     def _reset(self) -> None:
         """Reset internal states for a new run."""
-        self.counts = {}
-        self.sizes = {}
-        self.modified = False
+        self._counts = {}
+        self._sizes = {}
+        self._modified = False
         self._state = OptimizerState()
 
     def _do_inference(self, node: ir.Node) -> None:
@@ -896,7 +903,7 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None:
         # TODO: pass in constant values, ir_version
         try:
             schema = onnx.defs.get_schema(
-                node.op_type, self.opset_imports[node.domain], node.domain
+                node.op_type, self._opset_imports[node.domain], node.domain
             )
             output_types = onnx.shape_inference.infer_node_outputs(
                 schema,
@@ -937,7 +944,7 @@ def new_constant(self, node: ir.Node, value) -> ir.Node | None:
         tensor.name = irvalue.name
         irvalue.const_value = tensor
 
-        if value.nbytes > self._output_size_limit:
+        if value.nbytes > self.output_size_limit:
             # Handle examples like Transpose(weight) to be folded even if the size is large,
             # as long as weight has no other uses. This won't increase model size.
             removed_input_size = 0
@@ -967,6 +974,7 @@ def new_constant(self, node: ir.Node, value) -> ir.Node | None:
         return node
 
     def process_node(self, node: ir.Node) -> Replacement | None:
+        """Process a node and return a Replacement if the node can be replaced."""
         for i, value in enumerate(node.inputs):
             sym_value = self._state.get_sym_value(value)
             if isinstance(sym_value, ir.Value):
@@ -977,16 +985,16 @@ def process_node(self, node: ir.Node) -> Replacement | None:
                     sym_value.name,
                 )
                 node.replace_input_with(i, sym_value)
-                self.modified = True
+                self._modified = True
             # TODO(rama): consider merging type/other info from both values
 
         # Do incremental shape inference
-        if self._shape_inference and not is_control_flow_op(node):
+        if self.shape_inference and not _is_control_flow_op(node):
            self._do_inference(node)
 
-        if node.domain not in self.opset_imports:
+        if node.domain not in self._opset_imports:
            return None
-        version = self.opset_imports[node.domain]
+        version = self._opset_imports[node.domain]
         op_optimizers = registry.lookup_evaluators(node.domain, node.op_type, version)
         for optimizer in op_optimizers:
             assert optimizer
@@ -999,31 +1007,58 @@ def process_node(self, node: ir.Node) -> Replacement | None:
                 output = [output]
             return Replacement(output, context.nodes)
 
-        if is_control_flow_op(node) or is_non_deterministic_op(node):
+        if _is_control_flow_op(node) or _is_non_deterministic_op(node):
             return None
 
-        if is_onnx_op(node, "Constant"):
+        if _is_onnx_op(node, "Constant"):
             _process_constant_node(node)
             return None
 
-        input_values = [_get_numpy_value(x) for x in node.inputs]
-        if any(x is None for x in input_values):
-            return None
-
-        if any(self._state.is_initializer_input(x) for x in node.inputs):  # type: ignore[arg-type]
+        if any(x.is_graph_input() for x in node.inputs if x is not None):
+            # Do not fold any graph inputs to preserve graph signature
             return None
 
-        if any(input.nbytes > self._input_size_limit for input in input_values):  # type: ignore[union-attr]
+        # Ensure all node inputs are constants
+        if any(x.const_value is None for x in node.inputs if x is not None):
             if logger.isEnabledFor(logging.DEBUG):
-                input_sizes = [input.size for input in input_values]  # type: ignore[union-attr]
                 logger.debug(
-                    "Skipping constant folding for op %s due to large input size: %s",
-                    node.op_type,
-                    input_sizes,
+                    "Skipping constant folding for node %s because it has non-constant inputs: %s",
+                    node,
+                    [x.name for x in node.inputs if x is not None],
                 )
             return None
 
-        # Filter out bfloat16 cases?
+        input_tensors = [x.const_value if x is not None else None for x in node.inputs]
+
+        if any(
+            tensor.nbytes > self.input_size_limit
+            for tensor in input_tensors
+            if tensor is not None
+        ):
+            if (node.domain, node.op_type) in self.always_fold_ops and all(
+                len(input.consumers()) == 1 for input in node.inputs if input is not None
+            ):
+                # If the op is in always_fold_ops and all inputs are used only by this node,
+                # we can still fold it even if the input size exceeds the limit.
+                logger.debug(
+                    "Folding large constant for node %s because it is in the always_fold_ops list",
+                    node,
+                )
+            else:
+                # Skip folding large tensors
+                if logger.isEnabledFor(logging.DEBUG):
+                    input_sizes = [
+                        tensor.nbytes for tensor in input_tensors if tensor is not None
+                    ]
+                    logger.debug(
+                        "Skipping constant folding for node %s due to large input size: %s",
+                        node,
+                        input_sizes,
+                    )
+                return None
+
+        input_values = [_get_numpy_value(x) for x in node.inputs]
+
         def convert(av):
             if av.type == ir.AttributeType.TENSOR:
                 return ir.serde.serialize_tensor(av.value)
@@ -1038,7 +1073,7 @@ def convert(av):
             return None
         if len(node.outputs) == 1 and not isinstance(outputs, (tuple, list)):
             replacement = self.new_constant(node, outputs)
-            if is_onnx_op(node, "ConstantOfShape") or replacement is None:
+            if _is_onnx_op(node, "ConstantOfShape") or replacement is None:
                 return None
             return Replacement(replacement.outputs, [replacement])
         else:
@@ -1054,7 +1089,7 @@ def replace_node(self, node: ir.Node, replacement, root: ir.Graph | ir.Function)
             root, node, [node], replacement.new_nodes, node.outputs, replacement.new_outputs
         )
 
-        self.modified = True
+        self._modified = True
 
         # TODO: what about new opset_imports?
         # TODO: track statistics about replaced nodes and sizes of new constants
@@ -1079,13 +1114,6 @@ def visit_node(self, node: ir.Node, root: ir.Graph | ir.Function) -> None:
             self.replace_node(node, replacement, root)
 
     def visit_graph(self, graph: ir.Graph) -> None:
-        # Track inputs that have a const_value (which is really a default-value, and should not
-        # be used for constant-folding).
-        self._state.push_initializer_inputs()
-        for input in graph.inputs:
-            if input.const_value is not None:
-                self._state.add_initializer_input(input)
-
         for node in graph:
             self.visit_node(node, graph)
 
@@ -1103,22 +1131,20 @@ def visit_graph(self, graph: ir.Graph) -> None:
                     # Rename sym_value to match the output name
                     sym_value.name = output.name
                 graph.outputs[i] = sym_value
-                self.modified = True
-
-        self._state.pop_initializer_inputs()
+                self._modified = True
 
     def visit_function(self, function: ir.Function) -> None:
         for node in function:
             self.visit_node(node, function)
 
-    def call(self, model: ir.Model) -> ir.passes.PassResult:
+    def call(self, model: ir.Model) -> FoldConstantsResult:
         self._reset()
-        self.opset_imports = model.opset_imports
+        self._opset_imports = model.opset_imports
         self.visit_graph(model.graph)
         for function in model.functions.values():
             # TODO(rama): Should we specialize functions?
             self.visit_function(function)
-        return FoldConstantsResult(model, self.modified, self._state.symbolic_value_map)
+        return FoldConstantsResult(model, self._modified, self._state.symbolic_value_map)
 
 
 def _sym_value_can_replace_graph_output(
@@ -1155,6 +1181,7 @@ def fold_constants(
     onnx_shape_inference: bool = False,
     input_size_limit: int = DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT,
     output_size_limit: int = DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT,
+    always_fold_ops: Collection[str] = frozenset(["Transpose"]),
 ) -> FoldConstantsResult:
     """
     Applies constant folding optimization to the model.
@@ -1169,6 +1196,10 @@
         output_size_limit: The maximum size (in bytes) of output tensors
             that can be stored after constant folding. Defaults to
             `DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT`.
+        always_fold_ops: A collection of op types that should always be folded,
+            regardless of their input or output sizes. For ops from the default opset,
+            only the op_type is needed (e.g. "Transpose"), otherwise specify the domain
+            with ``{domain}::{op_type}``.
 
     Returns:
         An instance of `FoldConstantsResult`.
@@ -1178,5 +1209,6 @@
         shape_inference=onnx_shape_inference,
         input_size_limit=input_size_limit,
         output_size_limit=output_size_limit,
+        always_fold_ops=always_fold_ops,
     )
     return folder_pass(model)  # type: ignore[return-value]
diff --git a/onnxscript/optimizer/_constant_folding_test.py b/onnxscript/optimizer/_constant_folding_test.py
index 5a98cb5d51..20f116c7d9 100644
--- a/onnxscript/optimizer/_constant_folding_test.py
+++ b/onnxscript/optimizer/_constant_folding_test.py
@@ -536,14 +536,41 @@ def test_gather_symdim(self):
         optimized = self._fold(model)
         self.assertEqual(optimized.graph.node(-1).op_type, "Identity")
 
-    def test_large_transpose(self):
+    def test_input_size_limit(self):
+        model_text = """
+            <ir_version: 10, opset_import: ["" : 20]>
+            agraph (float[M, 256] x) => (float[M, 256] z)
+            # placeholder for large initializer of shape [256, 256]
+            {
+                w_squared = Mul (w, w)
+                z = Add (x, w_squared)
+            }
+        """
+        model = ir.from_onnx_text(model_text)
+        w = model.graph.initializers["w"]
+        w.shape = ir.Shape([256, 256])
+        w.const_value = ir.tensor(np.random.random((256, 256)).astype(np.float32))
+
+        # Input size limit will prevent folding of Mul op
+        optimized = self._fold(model, input_size_limit=3 * 256 * 256)
+        ops = [node.op_type for node in optimized.graph]
+        self.assertEqual(ops, ["Mul", "Add"])
+
+        # Input size limit will allow folding of Mul op
+        # Since there is no increase in model-size, output-size is not a concern.
+        optimized = self._fold(
+            model, input_size_limit=4 * 256 * 256, output_size_limit=4 * 256 * 256
+        )
+        ops = [node.op_type for node in optimized.graph]
+        self.assertEqual(ops, ["Constant", "Add"])
+
+    def test_transpose_is_always_folded(self):
         model_text = """
             <ir_version: 10, opset_import: ["" : 20]>
             agraph (float[M, 256] x) => (float[M, 512] z)
             # placeholder for large initializer of shape [512, 256]
             {
-                wt = Transpose (w)
-                z = MatMul (x, wt)
+                z = Transpose (w)
             }
         """
         model = ir.from_onnx_text(model_text)
@@ -551,16 +578,10 @@ def test_large_transpose(self):
         w.shape = ir.Shape([512, 256])
         w.const_value = ir.tensor(np.random.random((512, 256)).astype(np.float32))
 
-        # Input size limit will prevent folding of Transpose op
-        optimized = self._fold(model, input_size_limit=3 * 512 * 256)
-        ops = [node.op_type for node in optimized.graph]
-        self.assertEqual(ops, ["Transpose", "MatMul"])
-
-        # Input size limit will allow folding of Transpose op
-        # Since there is no increase in model-size, output-size is not a concern.
-        optimized = self._fold(model, input_size_limit=4 * 512 * 256)
+        # Input size limit will not prevent folding of Transpose op
+        optimized = self._fold(model, input_size_limit=1)
         ops = [node.op_type for node in optimized.graph]
-        self.assertEqual(ops, ["Constant", "MatMul"])
+        self.assertEqual(ops, ["Constant"])
 
     def test_multi_graph_identity_output_preserves_output_name(self):
         model = """