diff --git a/onnxscript/onnxrewriter/__init__.py b/onnxscript/onnxrewriter/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/onnxscript/onnxrewriter/ir/__init__.py b/onnxscript/onnxrewriter/ir/__init__.py new file mode 100644 index 0000000000..5ae285c015 --- /dev/null +++ b/onnxscript/onnxrewriter/ir/__init__.py @@ -0,0 +1,296 @@ +from __future__ import annotations + +import dataclasses +from collections import deque +from typing import List, Tuple, Union + +import numpy as np +import onnx + + +class Unknown: + """A special value used to indicate that a value is not a statically known constant. + + We use this instead of None because None is a valid constant value (since ONNX + supports the Optional type). + """ + + instance = None + + def __init__(self) -> None: + if Unknown.instance is not None: + raise ValueError("Unknown.instance is already set") + Unknown.instance = self + + +# Singleton instance of Unknown +unknown = Unknown() +NotConstant = unknown + +# ConcreteValue: This type represents constant values that an ONNX variable can take. +# TODO: Extend this to a recursive type to handle lists of tensors, etc., support optionals, +# maps, etc. +# TODO (rama): The value is sometimes stored as a numpy array, and sometimes as an ONNX TensorProto. +# A uniform representation would be helpful, but we should avoid unnecessary conversions for +# large tensors. Should be cleaned up in the new IR. +ConcreteValue = Union[onnx.TensorProto, np.ndarray, Unknown, None] + +# SymbolicValue: This information is used to enable partial-evaluation and specialization +# of sequence operations, as well as elimination of redundant Identity ops. +# The symbolic value of a variable X can be: +# - a string with the value "Y", indicating that "X" is a copy of "Y" +# - a list of strings, indicating that "X" is a list of tensors, with their symbolic values +# Eg., the symbolic value ["A", "B", "C"] indicates that the value of X is equal to +# "SequenceConstruct(A, B, C)". +# TODO: Technically, SymbolicValue should be a recursive type to handle lists of lists of +# tensors, etc. However, we currently only handle lists of tensors. + +SymbolicValue = Union[str, List[str]] + +FunctionId = Tuple[str, str, str] + + +def get_function_id(function: onnx.FunctionProto) -> FunctionId: + return (function.domain, function.name, getattr(function, "overload", "")) + + +def get_function_id_from_node(node: onnx.NodeProto) -> FunctionId: + return (node.domain, node.op_type, getattr(node, "overload", "")) + + +@dataclasses.dataclass +class StaticValueInfo: + name: str + value: ConcreteValue = NotConstant + type: onnx.TypeProto | None = None + symbolic_value: SymbolicValue | None = None + + def is_copy(self) -> bool: + return isinstance(self.symbolic_value, str) + + def tensor_shape_proto(self) -> onnx.TensorShapeProto | None: + """Returns the shape of a tensor or None. + + A return value of None could mean that the type is unknown or that the type is not a tensor + or that the tensor shape (that is, even the rank) is unknown. + """ + type = self.type + if type and type.HasField("tensor_type") and type.tensor_type.HasField("shape"): + return type.tensor_type.shape + return None + + @property + def shape(self) -> list[str | int | None] | None: + """Returns the shape in a list. + + Str means that the shape is dynamic. 
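+        An int is a static dimension (dim_value), and None means the dimension is unknown.
+        For example, a tensor typed float[N, 3] yields ["N", 3].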
+ """ + type = self.type + if type and type.HasField("tensor_type") and type.tensor_type.HasField("shape"): + dims = [] + for dim in type.tensor_type.shape.dim: + if dim.HasField("dim_param"): + dims.append(dim.dim_param) + elif dim.HasField("dim_value"): + dims.append(dim.dim_value) + else: + dims.append(None) + return dims + if self.value_as_np_array is not None: + return list(self.value_as_np_array.shape) + return None + + @property + def element_type(self) -> int | None: + """Returns the element type of a tensor, or None if type is not known or is not a tensor.""" + type = self.type + if type and type.HasField("tensor_type"): + return type.tensor_type.elem_type + return None + + def identity_merge_from(self, other: StaticValueInfo) -> None: + """Merge the value of other into self. + + This models the effect of an identity (copy) operation. + This will update static-analysis information based on incoming value. + """ + if not isinstance(other, StaticValueInfo): + raise TypeError(f"Cannot merge {other} into {self}.") + if other.value is not NotConstant: + self.value = other.value + # TODO: merge and combine best shape information from both types. + if other.tensor_shape_proto() is not None and other.element_type is not None: + self.type = other.type + # We cannot copy symbolic value across different scopes. + + # WIP: Extensions towards new IR: Note that the default construction of StaticValueInfo + # does not fill in the following fields. These fields are filled in by the IRBuilder + # which constructs the IR from the ONNX model. + node: Node | None = None + uses: list[Node] = dataclasses.field(default_factory=list) + output_index: int | None = None + is_output: bool = False + + @property + def const_value(self) -> ConcreteValue: + return self.value + + @property + def value_as_np_array(self) -> np.ndarray | None: + if isinstance(self.value, np.ndarray): + return self.value + if isinstance(self.value, onnx.TensorProto): + return onnx.numpy_helper.to_array(self.value) + return None + + def def_node(self) -> Node | None: + return self.node + + def def_index(self) -> int: + return self.output_index + + def is_same_as(self, other: StaticValueInfo) -> bool: + """Returns true if this value represents the same IR object as the other value. + + This is *not* value-equality, but rather object-equality. + """ + return self is other + + def __str__(self) -> str: + shape = self.shape + if shape is not None: + shape = [str(dim) for dim in shape] + shape_str = f"[{', '.join(shape)}]" + else: + shape_str = "None" + return ( + f"StaticValueInfo({self.name}, shape:{shape_str}, dtype:{self.element_type}, " + f"{'has const value' if self.value is not unknown else 'no const value'}.)" + ) + + +Value = StaticValueInfo + + +class Model: + def __init__(self) -> None: + self.gen_var_counter: int = 0 + + def set( + self, + model_proto: onnx.ModelProto, + graph: Graph, + functions: list[Function], + version_map: dict[str, int], + ) -> None: + """TODO. This is a temporary patch.""" + self.original_model_proto = model_proto + self.graph = graph + self.functions = functions + self.version_map = version_map + + def make_new_name(self): + # Temporary hack. + self.gen_var_counter += 1 + return f"_gen_{self.gen_var_counter}" + + def __str__(self) -> str: + # TODO: Naive string representation for debugging. Need to improve this. 
+ return "\n".join( + [ + f"ModelGraph: {self.graph}", + f"Functions: {self.functions}", + f"VersionMap: {self.version_map}", + ] + ) + + +class Graph: + def __init__(self, graph_proto: onnx.GraphProto): + self.original_graph_proto = graph_proto + self.nodes: deque[Node] = deque() + self.values: dict[str, Value] = {} + + @property + def name(self) -> str: + return self.original_graph_proto.name + + def __str__(self) -> str: + return "\n".join( + [ + "Graph", + f"Nodes: {[str(n) for n in self.nodes]}", + f"Values: {[str(v) for v in self.values]}", + ] + ) + + +class Function: + def __init__(self, function_proto: onnx.FunctionProto): + self.original_function_proto = function_proto + self.nodes = deque() + self.values = {} + + @property + def id(self) -> FunctionId: + return (self.domain, self.name, self.overload) + + @property + def domain(self) -> str: + return self.original_function_proto.domain + + @property + def name(self) -> str: + return self.original_function_proto.name + + @property + def overload(self) -> str: + return getattr(self.original_function_proto, "overload", "") + + def __str__(self) -> str: + return "\n".join( + [ + "Function", + f"Nodes: {[str(n) for n in self.nodes]}", + f"Values: {[str(v) for v in self.values]}", + ] + ) + + +class RefAttr: + def __init__(self, name: str, ref_attr_name: str, type) -> None: + self.name = name + self.ref_attr_name = ref_attr_name + self.type = type + + def to_proto(self) -> onnx.AttributeProto: + attr_proto = onnx.AttributeProto() + attr_proto.name = self.name + attr_proto.ref_attr_name = self.ref_attr_name + attr_proto.type = self.type + return attr_proto + + +class Node: + def __init__(self, node_proto: onnx.NodeProto) -> None: + self.original_node_proto = node_proto + self.domain: str = node_proto.domain + self.version: int | None = None + self.op_type: str = node_proto.op_type + self.inputs: list[Value | None] = [] + self.outputs: list[Value | None] = [] + self.attributes: dict[str, int | float | RefAttr | Graph | list[Graph]] = {} + + def get_attribute(self, name: str) -> int | float | None: + return self.attributes.get(name, None) + + def __str__(self) -> str: + return "\n".join( + [ + "Node", + f"OpType: {self.op_type}", + f"Inputs: {self.inputs}", + f"Outputs: {self.outputs}", + f"Attributes: {self.attributes}", + ] + ) diff --git a/onnxscript/onnxrewriter/ir/irbuilder.py b/onnxscript/onnxrewriter/ir/irbuilder.py new file mode 100644 index 0000000000..2b1abc9994 --- /dev/null +++ b/onnxscript/onnxrewriter/ir/irbuilder.py @@ -0,0 +1,217 @@ +from __future__ import annotations + +import warnings +from typing import Any + +import onnx + +from onnxrewriter import ir +from onnxrewriter.ir import visitor +from onnxrewriter.utils import utils + +""" NOTE: IRBuilder and function visiting + +Current IRBuilder is designed to visit function by definition, instead of function by callsite. +This has the following implications during visiting: +- Prior to IR 10 / ONNX 1.16, value_info is not defined in function. They are experimentally defined under + main graph for models produced by PyTorch 2.2+ dynamo onnx exporter. Hence a workaround is required in `process_node` + to load function value info from a pre-processed `FunctionShapeEnv` object. + Post IR 10, using `process_value_info` method is enough to retrieve and process both function and graph + value_info. +- ref_attr_name is not resolved during visiting, because it requires the function callsite information. 
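+- Example of the experimental value_info naming: a value "o" local to function "afunction" in domain
+  "pkg.custom" is recorded in the main graph under the name "pkg.custom::afunction/o", which
+  FunctionShapeEnv parses back into per-function value info.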
+ +""" + + +class IRBuilder: + def __init__(self): + self._current_graphs: list[ir.Graph] = [] + # See NOTE: IRBuilder and function visiting + self._current_function: ir.Function | None = None + self._function_subgraphs: list[ir.Graph] = [] + self.functions: dict[ir.FuntionId, ir.Function] = {} + + def visit_model(self, model_proto: onnx.ModelProto) -> ir.Model: + self._function_shape_env = visitor.FunctionShapeEnv() + self._function_shape_env.load_from_model_proto(model_proto) + self._ir_version = model_proto.ir_version + version_map = {x.domain: x.version for x in model_proto.opset_import} + functions = [ + self.visit_function(function) for function in model_proto.functions + ] + self.functions = {function.id: function for function in functions} + graph = self.visit_graph(model_proto.graph) + model = ir.Model() + model.set(model_proto, graph, functions, version_map) + return model + + def visit_graph(self, graph: onnx.GraphProto) -> ir.Graph: + self.enter_graph(ir.Graph(graph)) + for input in graph.input: + self.process_graph_input(input) + for init in graph.initializer: + self.process_initializer(init) + for node in graph.node: + self.process_node(node) + for output in graph.output: + self.process_graph_output(output) + for value_info in graph.value_info: + self.process_value_info(value_info) + return self.exit_graph() + + def visit_function(self, function: onnx.FunctionProto) -> ir.Function: + self._current_function = ir.Function(function) + for input in function.input: + self.process_function_input(input) + for node in function.node: + self.process_node(node) + for output in function.output: + self.process_function_output(output) + for value_info in getattr(function, "value_info", []): + self.process_value_info(value_info) + function_ir = self._current_function + self._current_function = None + return function_ir + + @property + def current_graph_or_function(self) -> ir.Graph | ir.Function: + if self._function_subgraphs: + assert self._current_function is not None + return self._function_subgraphs[-1] + if self._current_function is not None: + return self._current_function + return self._current_graphs[-1] + + def enter_graph(self, graph: ir.Graph): + if self._current_function is not None: + self._function_subgraphs.append(graph) + else: + self._current_graphs.append(graph) + + def exit_graph(self) -> ir.Graph: + if self._current_function is not None: + return self._function_subgraphs.pop() + else: + return self._current_graphs.pop() + + def _lookup_from_graphs(self, name: str, graphs: list[ir.Graph]) -> ir.Value | None: + for graph in reversed(graphs): + value = graph.values.get(name, None) + if value is not None: + return value + return None + + def lookup(self, name: str) -> ir.Value | None: + if self._current_function is not None: + value = self._lookup_from_graphs(name, self._function_subgraphs) + if value is not None: + return value + return self._current_function.values.get(name, None) + return self._lookup_from_graphs(name, self._current_graphs) + + def bind(self, name: str, value: ir.Value): + self.current_graph_or_function.values[name] = value + + def process_graph_input(self, input: onnx.ValueInfoProto): + newvalue = ir.Value(name=input.name, type=input.type) + self.bind(input.name, newvalue) + + def process_initializer(self, init: onnx.TensorProto): + # TODO(titaiwang): Take care of the case where the initializer is already defined? 
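+        # Note: if the name is already bound (e.g., as a graph input), the initializer is
+        # skipped and the existing binding (with its declared type) is kept.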
+        if init.name not in self.current_graph_or_function.values:
+            newvalue = ir.Value(name=init.name, value=init)
+            self.bind(init.name, newvalue)
+
+    def process_node(self, node: onnx.NodeProto):
+        node_ir = ir.Node(node)
+        self.current_graph_or_function.nodes.append(node_ir)
+        for input in node.input:
+            value = self.lookup(input)
+            node_ir.inputs.append(value)
+            if value is not None:
+                value.uses.append(node_ir)
+            else:
+                # TODO(titaiwang): Do something more than warnings?
+                warnings.warn(f"Use of undefined variable '{input}'.", stacklevel=1)
+        for index, output in enumerate(node.output):
+            newvalue = ir.Value(name=output, node=node_ir, output_index=index)
+            if self._current_function is not None:
+                ir_value = self._function_shape_env.lookup(
+                    self._current_function.original_function_proto, output
+                )
+                if ir_value is not None:
+                    newvalue.identity_merge_from(ir_value)
+            node_ir.outputs.append(newvalue)
+            self.bind(output, newvalue)
+        for attr in node.attribute:
+            attr_val = self.process_attribute(attr)
+            node_ir.attributes[attr.name] = attr_val
+        # Set constant-value for Constant node:
+        if node.op_type == "Constant" and node.domain in {"", "ai.onnx"}:
+            node_ir.outputs[0].value = utils.get_constant_node_value(
+                node, node.output[0]
+            )
+
+    def process_attribute(
+        self, attr: onnx.AttributeProto
+    ) -> ir.Graph | list[ir.Graph] | Any:
+        if attr.HasField("g"):
+            return self.visit_graph(attr.g)
+        elif len(attr.graphs) > 0:
+            return [self.visit_graph(graph) for graph in attr.graphs]
+        elif attr.ref_attr_name:
+            return ir.RefAttr(attr.name, attr.ref_attr_name, attr.type)
+        else:
+            # This returns Any based on onnx.helper.get_attribute_value's return type.
+            return onnx.helper.get_attribute_value(attr)
+
+    def process_graph_output(self, output: onnx.ValueInfoProto):
+        value = self.lookup(output.name)
+        if value is None:
+            # TODO(titaiwang): Should we remove the non-output value from the graph.values?
+            warnings.warn(
+                f"Graph contains no definition for output '{output.name}'.",
+                stacklevel=1,
+            )
+        else:
+            value.type = output.type
+            value.is_output = True
+
+    def process_function_input(self, input: str):
+        ir_value = self._function_shape_env.lookup(
+            self._current_function.original_function_proto, input
+        )
+        if ir_value is None:
+            ir_value = ir.Value(name=input)
+        self.bind(input, ir_value)
+
+    def process_function_output(self, output: str):
+        value = self.lookup(output)
+        if value is None:
+            warnings.warn(
+                f"Function contains no definition for output '{output}'.",
+                stacklevel=1,
+            )
+        else:
+            value.is_output = True
+
+    def process_value_info(self, value_info: onnx.ValueInfoProto):
+        function_id, ir_value = self._function_shape_env.process_value_info(value_info)
+        existing_value = self.lookup(value_info.name)
+        if existing_value is not None:
+            existing_value.identity_merge_from(ir_value)
+            ir_value = existing_value
+
+        if self._ir_version >= 10:  # noqa: PLR2004
+            # ONNX >= 1.16 where value_info can be defined in function
+            self.bind(ir_value.name, ir_value)
+        elif function_id is not None:
+            # All value_infos are defined in main graph
+            # This needs to be handled while visiting function, so do nothing here.
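+            # (See the module NOTE: for IR < 10, function value_info is recovered from the
+            # main graph via FunctionShapeEnv and applied in process_node instead.)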
+ pass + else: + self.bind(ir_value.name, ir_value) + + +def build_ir(model: onnx.ModelProto): + """Builds an IR from an ONNX model proto.""" + return IRBuilder().visit_model(model) diff --git a/onnxscript/onnxrewriter/ir/irbuilder_test.py b/onnxscript/onnxrewriter/ir/irbuilder_test.py new file mode 100644 index 0000000000..52cb8068dd --- /dev/null +++ b/onnxscript/onnxrewriter/ir/irbuilder_test.py @@ -0,0 +1,198 @@ +import unittest + +import onnx.parser + +from onnxrewriter.ir import irbuilder + + +class IRBuilderTest(unittest.TestCase): + def test_irbuilder(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x, float[N] y) => (float[N] z) + { + three = Constant () + x_cube = Pow(x, three) + B = Constant () + x_cube_mul_B = Mul(x_cube, B) + sum = Add(x, x_cube_mul_B) + C = Constant () + C_times_sum = Mul(C, sum) + tanh = Tanh(C_times_sum) + one = Constant () + one_plus_tanh = Add(one, tanh) + half = Constant () + half_x = Mul(half, x) + z = Mul(one_plus_tanh, half_x) + } + """ + ) + irbuilder.build_ir(model) + + def test_shape_is_accessible_for_graph_value_with_value_info(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x, float[N] y) => (float[N] z) + + { + t = Add (x, y) + z = Add (t, x) + } + """ + ) + irmodel = irbuilder.build_ir(model) + self.assertEqual( + irmodel.graph.nodes[0].outputs[0].tensor_shape_proto(), + onnx.TensorShapeProto(dim=[onnx.TensorShapeProto.Dimension(dim_param="N")]), + ) + + def test_shape_is_accessible_for_function_value_with_experimental_value_info(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x, float[N] y) => (float[N] z) + { + t = SequenceConstruct (x, y) + z = ConcatFromSequence (t) + } + + afunction (x, y) => (z) + { + o = MatMul (x, y) + shape = Constant () + z = Reshape (o, shape) + } + """ + ) + # Hack to put value_info in since parser does not support this experimental naming format + model.graph.value_info.append( + onnx.helper.make_tensor_value_info( + "pkg.custom::afunction/o", onnx.TensorProto.FLOAT, ["N", "K"] + ) + ) + irmodel = irbuilder.build_ir(model) + self.assertEqual( + irmodel.functions[0].nodes[0].outputs[0].tensor_shape_proto(), + onnx.TensorShapeProto( + dim=[ + onnx.TensorShapeProto.Dimension(dim_param="N"), + onnx.TensorShapeProto.Dimension(dim_param="K"), + ] + ), + ) + + def test_function_input_is_correctly_linked_with_subnodes_in_function_when_shape_is_missing( + self, + ): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x, float[M] y) => (float[Z] z) + { + z = afunction (x, y) + } + + afunction (x, y) => (z) + { + t = SequenceConstruct (x, y) + z = ConcatFromSequence (t) + } + """ + ) + irmodel = irbuilder.build_ir(model) + self.assertIsNotNone(irmodel.functions[0].nodes[0].inputs[0]) + self.assertIsNotNone(irmodel.functions[0].nodes[0].inputs[1]) + self.assertEqual( + irmodel.functions[0].nodes[0].inputs[0], irmodel.functions[0].values["x"] + ) + self.assertEqual( + irmodel.functions[0].nodes[0].inputs[1], irmodel.functions[0].values["y"] + ) + + def test_function_input_is_correctly_linked_with_subnodes_in_function_when_shape_is_present( + self, + ): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x, float[M] y) => (float[Z] z) + { + z = afunction (x, y) + } + + afunction (x, y) => (z) + { + t = SequenceConstruct (x, y) + z = ConcatFromSequence (t) + } + """ + ) + # Hack to put value_info in since parser does not support this experimental naming format + model.graph.value_info.extend( + [ + onnx.helper.make_tensor_value_info( + 
"pkg.custom::afunction/x", onnx.TensorProto.FLOAT, ["N"] + ), + onnx.helper.make_tensor_value_info( + "pkg.custom::afunction/y", onnx.TensorProto.FLOAT, ["M"] + ), + ] + ) + irmodel = irbuilder.build_ir(model) + self.assertIsNotNone(irmodel.functions[0].nodes[0].inputs[0]) + self.assertIsNotNone(irmodel.functions[0].nodes[0].inputs[1]) + self.assertEqual( + irmodel.functions[0].nodes[0].inputs[0], irmodel.functions[0].values["x"] + ) + self.assertEqual( + irmodel.functions[0].nodes[0].inputs[1], irmodel.functions[0].values["y"] + ) + self.assertEqual( + irmodel.functions[0].nodes[0].inputs[0].tensor_shape_proto(), + onnx.TensorShapeProto( + dim=[ + onnx.TensorShapeProto.Dimension(dim_param="N"), + ] + ), + ) + self.assertEqual( + irmodel.functions[0].nodes[0].inputs[1].tensor_shape_proto(), + onnx.TensorShapeProto( + dim=[ + onnx.TensorShapeProto.Dimension(dim_param="M"), + ] + ), + ) + + def test_out_of_context_value_reference_is_correct(self): + model = onnx.parser.parse_model( + """ + + agraph (float[16, 16] x, bool cond) => (float[16, 16] z) { + two = Constant () + z = If (cond) < + then_branch = then_graph () => (then_z) { + three = Constant () + temp = Add (two, three) + then_z = Mul (temp, x) + }, + else_branch = else_graph () => (else_z) { + four = Constant () + temp = Add (two, four) + else_z = Mul (temp, x) + } + > + } + """ + ) + irmodel = irbuilder.build_ir(model) + then_graph = irmodel.graph.nodes[1].attributes["then_branch"] + self.assertIsNotNone(then_graph.nodes[2].inputs[1]) + else_graph = irmodel.graph.nodes[1].attributes["else_branch"] + self.assertIsNotNone(else_graph.nodes[2].inputs[1]) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/onnxrewriter/ir/protobuilder.py b/onnxscript/onnxrewriter/ir/protobuilder.py new file mode 100644 index 0000000000..32ea6006c9 --- /dev/null +++ b/onnxscript/onnxrewriter/ir/protobuilder.py @@ -0,0 +1,127 @@ +from __future__ import annotations + +import onnx +import onnx.helper +from onnx.helper import make_attribute + +from onnxrewriter import ir + + +class ModelProtoBuilder: + def __init__(self): + self.opset_imports: dict[str, onnx.OperatorSetIdProto] = {} + + def visit_ir_model(self, ir_model: ir.Model) -> onnx.ModelProto: + model_proto = onnx.ModelProto() + model_proto.ir_version = ir_model.original_model_proto.ir_version + # TODO (sbhokare) : Find a way of copying model properties without + # each property individually + # Copy over model properties + model_proto.doc_string = ir_model.original_model_proto.doc_string + model_proto.domain = ir_model.original_model_proto.domain + model_proto.metadata_props.extend(ir_model.original_model_proto.metadata_props) + model_proto.model_version = ir_model.original_model_proto.model_version + model_proto.producer_name = ir_model.original_model_proto.producer_name + model_proto.producer_version = ir_model.original_model_proto.producer_version + model_proto.training_info.extend(ir_model.original_model_proto.training_info) + + for domain, version in ir_model.version_map.items(): + operator_setid_proto = model_proto.opset_import.add() + operator_setid_proto.domain, operator_setid_proto.version = domain, version + self.opset_imports[domain] = operator_setid_proto + for function in ir_model.functions: + function_proto = model_proto.functions.add() + self.visit_ir_function(function, function_proto) + graph_proto = model_proto.graph + self.visit_ir_graph(ir_model.graph, graph_proto) + return model_proto + + def visit_ir_graph( + self, ir_graph: ir.Graph, graph_proto: 
onnx.GraphProto + ) -> onnx.GraphProto: + graph_proto.name = ir_graph.name + # Copy over graph properties + graph_proto.doc_string = ir_graph.original_graph_proto.doc_string + # graph_proto.metadata_props = ir_graph.original_graph_proto.metadata_props) + graph_proto.quantization_annotation.extend( + ir_graph.original_graph_proto.quantization_annotation + ) + + for node in ir_graph.nodes: + node_proto = graph_proto.node.add() + self.process_ir_node(node, node_proto) + for i in ir_graph.original_graph_proto.input: + graph_proto.input.append(i) + for o in ir_graph.original_graph_proto.output: + graph_proto.output.append(o) + for val in ir_graph.original_graph_proto.value_info: + graph_proto.value_info.append(val) + for i in ir_graph.original_graph_proto.initializer: # type: ignore[assignment] + graph_proto.initializer.append(i) # type: ignore[arg-type] + return graph_proto + + def visit_ir_function( + self, ir_function: ir.Function, function_proto: onnx.FunctionProto + ) -> onnx.FunctionProto: + function_proto.name = ir_function.name + function_proto.domain = ir_function.domain + # Copy over function properties + function_proto.doc_string = ir_function.original_function_proto.doc_string + # function_proto.metadata_props = ir_function.original_function_proto.metadata_props) + + for node in ir_function.nodes: + operator_setid_proto = function_proto.opset_import.add() + if node.domain in self.opset_imports: + operator_setid_proto.domain = self.opset_imports[node.domain].domain + operator_setid_proto.version = self.opset_imports[node.domain].version + else: + raise ValueError(f"Unknown domain {node.domain}") + node_proto = function_proto.node.add() + self.process_ir_node(node, node_proto) + # TODO (shubham) : Propagate shape-type info + for i in ir_function.original_function_proto.input: + function_proto.input.append(i) + for o in ir_function.original_function_proto.output: + function_proto.output.append(o) + for attr in ir_function.original_function_proto.attribute: + function_proto.attribute.append(attr) + for attr_proto in ir_function.original_function_proto.attribute_proto: + function_proto.attribute_proto.append(attr_proto) + for val in getattr(ir_function.original_function_proto, "value_info", []): + function_proto.value_info.append(val) + return function_proto + + def process_ir_node( + self, ir_node: ir.Node, node_proto: onnx.NodeProto + ) -> onnx.NodeProto: + node_proto.op_type = ir_node.op_type + node_proto.domain = ir_node.domain + # Copy over node properties + node_proto.name = ir_node.original_node_proto.name + node_proto.doc_string = ir_node.original_node_proto.doc_string + # node_proto.metadata_props = ir_node.original_node_proto.metadata_props) + + for i in ir_node.inputs: + node_proto.input.append(i.name if i is not None else "") + for o in ir_node.outputs: + assert o is not None + node_proto.output.append(o.name) + for attr in ir_node.attributes.items(): + attr_proto = self.process_attribute(attr) + node_proto.attribute.append(attr_proto) + return node_proto + + def process_attribute(self, attr): + attr_name, attr_val = attr + if isinstance(attr_val, ir.RefAttr): + return attr_val.to_proto() + if isinstance(attr_val, ir.Graph): + graph_proto = onnx.GraphProto() + attr_val = self.visit_ir_graph(attr_val, graph_proto) + attr_proto = make_attribute(attr_name, attr_val) + return attr_proto + + +def build_model_proto(model: ir.Model) -> onnx.ModelProto: + """Builds an ONNX model proto from an IR.""" + return ModelProtoBuilder().visit_ir_model(model) diff --git 
a/onnxscript/onnxrewriter/ir/protobuilder_test.py b/onnxscript/onnxrewriter/ir/protobuilder_test.py new file mode 100644 index 0000000000..e147404f99 --- /dev/null +++ b/onnxscript/onnxrewriter/ir/protobuilder_test.py @@ -0,0 +1,217 @@ +import unittest + +import numpy as np +import onnx.checker +import onnx.parser + +from onnxrewriter.ir import irbuilder, protobuilder +from onnxrewriter.rewriter import pattern +from onnxrewriter.rewriter.onnxruntime import instance_to_group_normalization + +op = pattern.onnxop + + +class ConcatSerializeTest(unittest.TestCase): + def rule(self) -> pattern.RewriteRule: + def concat_pattern(x, y, axis): + seq = op.SequenceConstruct(x, y) + return op.ConcatFromSequence(seq, axis=axis) + + def concat(x, y, axis): + return op.Concat(x, y, axis=axis) + + return pattern.RewriteRule(concat_pattern, concat) + + def test_concat_serialize(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x, float[N] y) => (float[M] z) + { + t = SequenceConstruct (x, y) + z = ConcatFromSequence (t) + } + """ + ) + ir = irbuilder.build_ir(model) + count = self.rule().apply_to_model(ir) + # Tests related to IR + self.assertEqual(count, 1) + self.assertEqual(len(ir.graph.nodes), 1) + # Tests related to serialization to ModelProto + model_proto = protobuilder.build_model_proto(ir) + onnx.checker.check_model(model_proto) + + def test_concat_in_function_serialize(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x, float[M] y) => (float[Z] z) + { + z = pkg.custom.afunction (x, y) + } + + afunction (x, y) => (z) + { + t = SequenceConstruct (x, y) + z = ConcatFromSequence (t) + } + """ + ) + ir = irbuilder.build_ir(model) + count = self.rule().apply_to_model(ir) + # Tests related to IR + self.assertEqual(count, 1) + self.assertEqual(len(ir.functions), 1) + self.assertEqual(len(ir.functions[0].nodes), 1) + self.assertEqual(ir.functions[0].nodes[0].op_type, "Concat") + # Tests related to serialization to ModelProto + model_proto = protobuilder.build_model_proto(ir) + onnx.checker.check_model(model_proto) + + def test_concat_in_nested_function_serialize(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x, float[M] y) => (float[Z] z) + { + z = pkg.custom.afunction (x, y) + } + + afunction (x, y) => (z) + { + z = pkg.custom.nestedfunction(x, y) + } + + nestedfunction (x, y) => (z) + { + t = SequenceConstruct (x, y) + z = ConcatFromSequence (t) + } + """ + ) + ir = irbuilder.build_ir(model) + count = self.rule().apply_to_model(ir) + # Tests related to IR + self.assertEqual(count, 1) + self.assertEqual(len(ir.functions), 2) + self.assertEqual(len(ir.functions[0].nodes), 1) + self.assertEqual(len(ir.functions[1].nodes), 1) + self.assertEqual(ir.functions[0].nodes[0].op_type, "nestedfunction") + self.assertEqual(ir.functions[1].nodes[0].op_type, "Concat") + # Tests related to serialization to ModelProto + model_proto = protobuilder.build_model_proto(ir) + onnx.checker.check_model(model_proto) + + +class ControlFlowSerializeTest(unittest.TestCase): + def test_conditional_serialize(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x) => (float[N] y) + { + f = Constant () + t = Constant () + y1 = local.myfun (f, x) + y = local.myfun (t, y1) + } + + myfun (b, lx) => (ly) + { + ly = If (b) < + then_branch = g1 () => (float[N] z_then) + { + two = Constant () + z_then = Mul (lx, two) + }, + else_branch = g2 () => (float[N] z_else) + { + three = Constant () + z_else = Mul (lx, three) + } + > + } + """ + ) + ir = 
irbuilder.build_ir(model) + # Tests related to serialization to ModelProto + model_proto = protobuilder.build_model_proto(ir) + onnx.checker.check_model(model_proto) + + def test_function_attribute_serialize(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x) => (float[N] y) + { + f = Constant () + t = Constant () + y1 = local.myfun (f, x) + y = local.myfun (t, y1) + } + + myfun (l, lx) => (ly) + { + ly = Mul (l, lx) + } + """ + ) + ir = irbuilder.build_ir(model) + model_proto = protobuilder.build_model_proto(ir) + onnx.checker.check_model(model_proto) + function_proto = model_proto.functions[0] + self.assertEqual(function_proto.attribute, ["a"]) + self.assertEqual(len(function_proto.attribute_proto), 1) + b_attr_proto = function_proto.attribute_proto[0] + self.assertEqual(b_attr_proto.name, "b") + self.assertEqual(b_attr_proto.type, onnx.AttributeProto.INT) + self.assertEqual(b_attr_proto.i, 1) + + def test_com_microsoft_opset_is_supported_in_protobuilder(self): + model = onnx.parser.parse_model( + """ + + agraph (float[1, 320, 128, 128] image) => (float[1, 4, 512, 64] output) + { + shape_a = Constant() + image_reshape = Reshape (image, shape_a) + instance_norm = InstanceNormalization (image_reshape, scale, B) + shape_b = Constant() + instance_norm_reshape = Reshape (instance_norm, shape_b) + mul_output = Mul (instance_norm_reshape, weight) + output = Add (mul_output, bias) + } + """ + ) + # Use inserted initializers to avoid manually coding the large constants + weight = np.random.rand(320, 1, 1).astype(np.float16) + bias = np.random.rand(320, 1, 1).astype(np.float16) + model.graph.initializer.extend( + [ + onnx.helper.make_tensor( + "scale", + onnx.TensorProto.FLOAT16, + [32], + np.ones(32, dtype=np.float16), + ), + onnx.helper.make_tensor( + "B", onnx.TensorProto.FLOAT16, [32], np.zeros(32, dtype=np.float16) + ), + onnx.helper.make_tensor( + "weight", onnx.TensorProto.FLOAT16, [320, 1, 1], weight + ), + onnx.helper.make_tensor( + "bias", onnx.TensorProto.FLOAT16, [320, 1, 1], bias + ), + ] + ) + ir = irbuilder.build_ir(model) + count = instance_to_group_normalization.rules.apply_to_model(ir) + self.assertEqual(count, 1) + model_proto = protobuilder.build_model_proto(ir) + onnx.checker.check_model(model_proto) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/onnxrewriter/ir/visitor.py b/onnxscript/onnxrewriter/ir/visitor.py new file mode 100644 index 0000000000..841c666a5a --- /dev/null +++ b/onnxscript/onnxrewriter/ir/visitor.py @@ -0,0 +1,937 @@ +from __future__ import annotations + +import dataclasses +import logging +from typing import Any, Sequence + +import numpy as np +import onnx + +from onnxrewriter import ir +from onnxrewriter.utils.utils import ( + get_initializer_type, + is_control_flow_op, + normalize_domain, +) + +logger = logging.getLogger(__name__) + + +def _override_inferred_value_type_with_symbolic_value_type( + symbolic_value: ir.Value | None, + inferred_value: ir.Value | None, +) -> ir.Value | None: + if inferred_value is not None and symbolic_value is not None: + inferred_value.type = symbolic_value.type + if inferred_value is None: + inferred_value = symbolic_value + return inferred_value + + +def is_local_function_node( + node: onnx.NodeProto, functions: dict[ir.FunctionId, onnx.FunctionProto] +) -> bool: + return ir.get_function_id_from_node(node) in functions + + +class FunctionShapeEnv: + def __init__(self): + # Mapping from (domain, function_name, overload) to {value_name: ir_value} + self._function_values: 
dict[ir.FunctionId, dict[str, ir.Value]] = {} + + def load_from_model_proto(self, model_proto: onnx.ModelProto) -> None: + for value_info in model_proto.graph.value_info: + self.load_from_value_info(value_info) + + def save_to_model_proto(self, model_proto: onnx.ModelProto) -> None: + for ( + domain, + function_name, + overload, + ), named_ir_values in self._function_values.items(): + for ir_value in named_ir_values.values(): + if ( + value_info := self.save_to_value_info( + ir_value, domain, function_name, overload + ) + ) is not None: + model_proto.graph.value_info.append(value_info) + + def load_from_value_info(self, value_info: onnx.ValueInfoProto) -> None: + function_id, ir_value = self.process_value_info(value_info) + if function_id is not None: + logger.debug( + "Loads torch symbolic value info '%s'.", + value_info.name, + ) + self._function_values.setdefault(function_id, {})[ir_value.name] = ir_value + + def process_value_info( + self, value_info: onnx.ValueInfoProto + ) -> tuple[ir.FunctionId | None, ir.Value]: + name = value_info.name + if len(splits := name.split("/")) == 2: # noqa: PLR2004 + # Experimental function value info format. + # To be deprecated after ONNX 1.16, where value_info is introduced in FunctionProto. + function_id, value_name = splits + splits = function_id.split("::") + domain, function_name = splits[0], splits[1] + # 'overload' is introduced in ONNX 1.16, consider it as empty string prior to that. + # The code is for future proof, in case overload is encoded in this format. + overload = "" + if len(splits) == 3: # noqa: PLR2004 + overload = splits[2] + function_id = (domain, function_name, overload) + else: + # Standard main graph value info format. + function_id = None + value_name = name + return function_id, ir.Value(value_name, type=value_info.type) + + def save_to_value_info( + self, value: ir.Value, domain: str, function_name: str, overload: str + ) -> onnx.ValueInfoProto | None: + if overload != "": + raise NotImplementedError("Overload is not supported yet.") + function_id = f"{domain}::{function_name}" + + if value.type is not None: + return onnx.helper.make_value_info( + f"{function_id}/{value.name}", value.type + ) + return None + + def lookup(self, function: onnx.FunctionProto, value_name: str) -> ir.Value | None: + """Lookup ir value of 'value_name' inside 'function'.""" + function_id = ir.get_function_id(function) + function_values = self._function_values.get(function_id) + if ( + function_values is None + or (ir_value := function_values.get(value_name)) is None + ): + logger.debug( + "Lookup Missed %s torch symbolic value info in function %s::%s.", + value_name, + function.domain, + function.name, + ) + return None + logger.debug( + "Lookup found %s torch symbolic value info in function %s::%s.", + value_name, + function.domain, + function.name, + ) + return ir_value + + def bind( + self, value: ir.Value, domain: str, function_name: str, overload: str + ) -> None: + """Bind ir value 'value' to 'value_name' inside 'function'.""" + function_id = (domain, function_name, overload) + self._function_values.setdefault(function_id, {})[value.name] = value + + def get_ir_values(self, function: onnx.FunctionProto) -> dict[str, ir.Value]: + """Get all ir values inside 'function'.""" + function_id = ir.get_function_id(function) + return self._function_values.get(function_id, {}) + + +class SubScope: + values: dict[str, ir.Value] + ref_attributes: dict[str, onnx.AttributeProto] + owner: onnx.GraphProto | onnx.FunctionProto + + def __init__(self, owner: 
onnx.GraphProto | onnx.FunctionProto): + self.values = {} + self.ref_attributes = {} + self.owner = owner + + def lookup(self, name: str) -> ir.Value | None: + return self.values.get(name) + + def bind(self, name: str, value: ir.Value) -> None: + self.values[name] = value + + def lookup_ref_attribute(self, ref_attr_name: str) -> onnx.AttributeProto | None: + return self.ref_attributes.get(ref_attr_name) + + def bind_ref_attribute(self, ref_attr_name: str, attr: onnx.AttributeProto) -> None: + self.ref_attributes[ref_attr_name] = attr + + def readable_strs(self, indent: int = 0) -> list[str]: + indent_str = " " * indent + strs = [] + if isinstance(self.owner, onnx.GraphProto): + strs.append(f"Graph {self.owner.name}:") + else: + strs.append(f"Function {self.owner.name}:") + strs.append(" ir.Values:") + for name, value in self.values.items(): + strs.append(f" {name}: {value}") + strs.append(" RefAttributes:") + for name, attr in self.ref_attributes.items(): + strs.append(f" {name}: {attr}") + + return [f"{indent_str}{s}" for s in strs] + + def __str__(self) -> str: + return "\n".join(self.readable_strs()) + + +@dataclasses.dataclass +class Scope: + _sub_scopes: list[SubScope] = dataclasses.field(default_factory=list) + + def lookup(self, name: str) -> ir.Value | None: + """Lookup value by name from all SubScopes.""" + for sub_scope in reversed(self._sub_scopes): + if (result := sub_scope.lookup(name)) is not None: + return result + return None + + def bind(self, name: str, value: ir.Value) -> None: + """Bind value to name in the most recent SubScope.""" + if name == "": + raise ValueError("Cannot bind to empty name.") + if value is None: + raise ValueError(f"Cannot bind None to value {name}.") + self._sub_scopes[-1].bind(name, value) + + def lookup_or_create(self, name: str) -> ir.Value: + """Lookup value by name from all SubScopes. 
If not found, create a new one in most recent SubScope.""" + if name == "": + raise ValueError("Cannot lookup or create empty name.") + for sub_scope in reversed(self._sub_scopes): + if (result := sub_scope.lookup(name)) is not None: + return result + value = ir.Value(name=name) + self.bind(name, value) + return value + + def lookup_ref_attribute(self, ref_attr_name: str) -> onnx.AttributeProto | None: + for sub_scope in reversed(self._sub_scopes): + if (result := sub_scope.lookup_ref_attribute(ref_attr_name)) is not None: + return result + return None + + def bind_ref_attribute(self, ref_attr_name: str, attr: onnx.AttributeProto) -> None: + self._sub_scopes[-1].bind_ref_attribute(ref_attr_name, attr) + + def enter_sub_scope(self, owner: onnx.GraphProto) -> None: + self._sub_scopes.append(SubScope(owner)) + + def exit_sub_scope(self) -> SubScope: + return self._sub_scopes.pop() + + def current_function_scope(self) -> SubScope | None: + if len(self._sub_scopes) == 0: + return None + if isinstance(self._sub_scopes[0].owner, onnx.FunctionProto): + return self._sub_scopes[0] + return None + + def current_function(self) -> onnx.FunctionProto | None: + current_function_scope = self.current_function_scope() + if current_function_scope is not None: + return current_function_scope.owner + return None + + def current_graph(self) -> onnx.GraphProto | None: + for sub_scope in reversed(self._sub_scopes): + if isinstance(sub_scope.owner, onnx.GraphProto): + return sub_scope.owner + return None + + def readable_strs(self, indent: int = 0) -> list[str]: + indent_str = " " * indent + strs = [] + for i, sub_scope in enumerate(self._sub_scopes): + strs.append(f"SubScope {i}:") + strs.extend(sub_scope.readable_strs(indent=indent + 2)) + return [f"{indent_str}{s}" for s in strs] + + def __str__(self) -> str: + return "\n".join(self.readable_strs()) + + +@dataclasses.dataclass +class ScopeStack: + """Stack of scopes. + + Each Scope represents statically-nested SubScopes (where inner SubScopes can access names defined in outer SubScopes) + produced by subgraphs (occurring as attribute values), except for the first SubScope which could be produced by a function. + With a ScopeStack, there is no such possibility of referencing variables defined higher up in the stack by name. + Instead, it is meant to represent a sequence of (nested) function-calls. Each entry in the stack (except the outermost) + represents a call to a function. + + Thus, we would use a ScopeStack for a context-sensitive analysis (where we recursively process a called function). + For a context-insensitive analysis, we would only need a Scope (where we recursively process subgraphs). + + To debug, `print(scope_stack)` will print the scope structure as well as the info stored + in each scope. + """ + + _scopes: list[Scope] = dataclasses.field(default_factory=lambda: [Scope()]) + + def current_scope(self) -> Scope: + return self._scopes[-1] + + def lookup(self, name: str) -> ir.Value | None: + """Lookup value by name from the current Scope.""" + return self.current_scope().lookup(name) + + def bind(self, name: str, value: ir.Value) -> None: + """Bind value to name in the current Scope.""" + self.current_scope().bind(name, value) + + def lookup_or_create(self, name: str) -> ir.Value: + """Lookup value by name from the current Scope. 
If not found, create a new one.""" + return self.current_scope().lookup_or_create(name) + + def lookup_ref_attribute(self, ref_attr_name: str) -> onnx.AttributeProto | None: + return self.current_scope().lookup_ref_attribute(ref_attr_name) + + def bind_ref_attribute(self, ref_attr_name: str, attr: onnx.AttributeProto) -> None: + self.current_scope().bind_ref_attribute(ref_attr_name, attr) + + def enter_graph_scope(self, graph: onnx.GraphProto) -> None: + self.current_scope().enter_sub_scope(graph) + + def exit_graph_scope(self) -> SubScope: + sub_scope = self.current_scope().exit_sub_scope() + assert isinstance(sub_scope.owner, onnx.GraphProto), "Expected graph scope." + return sub_scope + + def enter_function_scope(self, function: onnx.FunctionProto) -> None: + self._scopes.append(Scope()) + self.current_scope().enter_sub_scope(function) + + def exit_function_scope(self) -> SubScope: + sub_scope = self.current_scope().exit_sub_scope() + assert isinstance( + sub_scope.owner, onnx.FunctionProto + ), "Expected function scope." + self._scopes.pop() + return sub_scope + + def current_function(self) -> onnx.FunctionProto | None: + return self.current_scope().current_function() + + def current_graph(self) -> onnx.GraphProto | None: + return self.current_scope().current_graph() + + def __str__(self) -> str: + strs = ["ScopeStach:"] + for i, scope in enumerate(self._scopes): + strs.append(f" Scope {i}:") + strs.extend(scope.readable_strs(indent=2)) + return "\n".join(strs) + + +class ProtoVisitorCore: + def visit_model(self, model: onnx.ModelProto): + self.process_model(model) + for opset in model.opset_import: + self.process_opset_import(opset) + self.visit_graph(model.graph) + for function in model.functions: + self.visit_function(function) + + def process_model(self, model: onnx.ModelProto): + pass + + def process_opset_import(self, opset: onnx.OperatorSetIdProto): + pass + + def visit_graph(self, graph: onnx.GraphProto): + self.enter_scope(graph) + self.process_graph(graph) + for input in graph.input: + self.process_graph_input(input) + for init in graph.initializer: + self.process_initializer(init) + for value_info in graph.value_info: + self.process_value_info(value_info) + for node in graph.node: + self.visit_node(node) + for output in graph.output: + self.process_graph_output(output) + self.exit_scope(graph) + + def visit_function(self, function: onnx.FunctionProto): + self.enter_function_scope(function) + self.process_function(function) + for input in function.input: + self.process_function_input(input) + for node in function.node: + self.visit_node(node) + for output in function.output: + self.process_function_output(output) + self.exit_function_scope(function) + + def process_function_input(self, input: str): + pass + + def process_function_output(self, output: str): + pass + + def process_function(self, function: onnx.FunctionProto): + pass + + def enter_function_scope(self, function: onnx.FunctionProto): + pass + + def exit_function_scope(self, function: onnx.FunctionProto) -> SubScope: + pass + + def enter_scope(self, graph: onnx.GraphProto): + pass + + def process_graph(self, graph: onnx.GraphProto): + pass + + def exit_scope(self, graph: onnx.GraphProto) -> SubScope: + pass + + def process_graph_input(self, input: onnx.ValueInfoProto): + pass + + def process_initializer(self, init: onnx.TensorProto): + pass + + def process_value_info(self, value_info: onnx.ValueInfoProto): + pass + + def visit_node(self, node: onnx.NodeProto): + self.process_node(node) + for attr in node.attribute: 
+ self.visit_attribute(attr) + + def process_node(self, node: onnx.NodeProto) -> Sequence[onnx.NodeProto] | None: + pass + + def process_graph_output(self, output: onnx.ValueInfoProto): + pass + + def visit_attribute(self, attr: onnx.AttributeProto): + self.process_attribute(attr) + if attr.HasField("g"): + self.visit_graph(attr.g) + elif len(attr.graphs) > 0: + for graph in attr.graphs: + self.visit_graph(graph) + + def process_attribute(self, attr: onnx.AttributeProto): + pass + + +class ProtoVisitor(ProtoVisitorCore): + def __init__( + self, external_data_folder: str = "", *, do_shape_inference: bool = False + ) -> None: + super().__init__() + self.scopes = ScopeStack() + self.function_shape_env = FunctionShapeEnv() + self.version_map = {} # Map from domain to version + self.do_shape_inference = do_shape_inference + self.external_data_folder = external_data_folder + self.modified = False + + def process_opset_import(self, opset: onnx.OperatorSetIdProto): + domain = normalize_domain(opset.domain) + self.version_map[domain] = opset.version + + def lookup_version(self, domain: str) -> int: + domain = normalize_domain(domain) + return self.version_map.get(domain, 1) # TODO: handle missing domain + + def lookup(self, name: str) -> ir.Value | None: + if name == "": + return None + if (result := self.scopes.lookup(name)) is None: + logger.debug("Lookup value %s unfound.", name) + raise ValueError( + f"Undefined variable {name}.\n" + f"Available variables: {self.scopes.current_scope()}" + ) + logger.debug("Lookup value %s. Shape %s", name, result.tensor_shape_proto()) + return result + + def bind(self, name: str, value: ir.Value) -> None: + logger.debug("Binding value %s. Shape %s", name, value.tensor_shape_proto()) + self.scopes.bind(name, value) + + def lookup_or_create(self, name: str) -> ir.Value: + return self.scopes.lookup_or_create(name) + + def has_input(self, node: onnx.NodeProto, index: int) -> bool: + return index < len(node.input) and node.input[index] != "" + + # TODO: Cleanup handling of undefined variables. May fail in some of methods below. 
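+    # Note: lookup() raises for names that are used but never defined, while the get_*/input_*
+    # helpers below return None when an input or output slot is simply absent (empty name or
+    # out-of-range index).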
+ + def get_input(self, node: onnx.NodeProto, index: int) -> ir.Value | None: + if index < len(node.input): + return self.lookup(node.input[index]) + return None + + def input_type(self, node: onnx.NodeProto, index: int) -> onnx.TypeProto | None: + info = self.get_input(node, index) + return info.type if info is not None else None + + def input_element_type(self, node: onnx.NodeProto, index: int) -> int | None: + info = self.get_input(node, index) + return info.element_type if info is not None else None + + def input_shape( + self, node: onnx.NodeProto, index: int + ) -> onnx.TensorShapeProto | None: + info = self.get_input(node, index) + return info.tensor_shape_proto() if info is not None else None + + def input_const_value(self, node: onnx.NodeProto, index: int) -> Any: + if not self.has_input(node, index): + return None # This is treated as a known constant value "None" + info = self.get_input(node, index) + return info.value + + def has_output(self, node: onnx.NodeProto, index: int) -> bool: + return index < len(node.output) and node.output[index] != "" + + def get_output(self, node: onnx.NodeProto, index: int) -> ir.Value | None: + if index < len(node.output): + return self.lookup(node.output[index]) + return None + + def get_input_value( + self, node: onnx.NodeProto, index: int, default: Any | None = None + ) -> Any | None: + info = self.get_input(node, index) + if info is not None: + return info.value + return default + + def get_input_type( + self, node: onnx.NodeProto, index: int, default: onnx.TypeProto | None = None + ) -> onnx.TypeProto | None: + info = self.get_input(node, index) + if info is not None: + return info.type + return default + + def enter_scope(self, graph: onnx.GraphProto): + logger.debug("enter_scope: graph %s", graph.name) + self.scopes.enter_graph_scope(graph) + + def exit_scope(self, graph: onnx.GraphProto) -> SubScope: + logger.debug("exit_scope: graph %s", graph.name) + return self.scopes.exit_graph_scope() + + def enter_function_scope(self, function: onnx.FunctionProto): + logger.debug("enter_function_scope: function %s", function.name) + self.scopes.enter_function_scope(function) + ir_values = self.function_shape_env.get_ir_values(function) + for name, ir_value in ir_values.items(): + inferred_ir_value = self.lookup_or_create(name) + updated_ir_value = _override_inferred_value_type_with_symbolic_value_type( + ir_value, inferred_ir_value + ) + self.bind(name, updated_ir_value) + + def exit_function_scope(self, function: onnx.FunctionProto) -> SubScope: + logger.debug("exit_function_scope: function %s", function.name) + # Sync ir value back to function_shape_env + function_scope = self.scopes.exit_function_scope() + for ir_value in function_scope.values.values(): + self.function_shape_env.bind(ir_value, *ir.get_function_id(function)) + return function_scope + + def process_initializer(self, init: onnx.TensorProto): + array = onnx.numpy_helper.to_array(init, self.external_data_folder) + self.bind( + init.name, + ir.Value(name=init.name, value=array, type=get_initializer_type(init)), + ) + + def process_graph_input(self, input: onnx.ValueInfoProto): + self.bind(input.name, ir.Value(name=input.name, type=input.type)) + + def process_value_info(self, value_info: onnx.ValueInfoProto): + logger.debug("process_value_info: %s", value_info) + value = self.lookup_or_create(value_info.name) + value.type = value_info.type + # Populate function shape environment + self.function_shape_env.load_from_value_info(value_info) + + def process_node(self, node: onnx.NodeProto) 
-> Sequence[onnx.NodeProto] | None: + output_types = {} + if self.do_shape_inference and not is_control_flow_op(node): + # Control-flow ops are more complicated. Not supported here yet. + # TODO: handle optional inputs + def get_constant_value(i: int) -> onnx.TensorProto | None: + value = self.input_const_value(node, i) + if isinstance(value, np.ndarray) and value.size < 20: # noqa: PLR2004 + return onnx.numpy_helper.from_array(value, node.input[i]) + return None + + input_types = { + x: self.input_type(node, i) for i, x in enumerate(node.input) + } + input_data = {x: get_constant_value(i) for i, x in enumerate(node.input)} + input_data = {k: v for k, v in input_data.items() if v is not None} + if any(t is None for t in input_types.values()): + logger.debug( + "Skipping shape inference for node %s due to missing input type.", + node.name, + ) + else: + # TODO: pass in constant values, ir_version + try: + schema = onnx.defs.get_schema( + node.op_type, self.lookup_version(node.domain), node.domain + ) + output_types = onnx.shape_inference.infer_node_outputs( + schema, node, input_types, input_data + ) + except Exception as e: # noqa: BLE001 + logger.debug( + "Skipping shape inference for node %s due to exception: %s", + node.name, + e, + ) + + for output in node.output: + info = self.lookup_or_create(output) + if output in output_types: + # TODO: merge types + info.type = output_types[output] + + +class ProtoTransformer(ProtoVisitor): + # TODO(lowpri) Practically this is useless. + # Subgraph only exist in 'if' nodes. 'if' nodes only exist in torchlib functions. + # There is no pre-existing value_info in torchlib functions. + # def exit_scope(self, graph: onnx.GraphProto) -> SubScope: + # # Also sync updated ir values back to value_info in graph. + # sub_scope = super().exit_scope(graph) + + def visit_node(self, node: onnx.NodeProto) -> list[onnx.NodeProto] | None: + replacement = self.process_node(node) + logger.debug( + "visit_node: %s::%s %s replacement %s", + node.domain, + node.op_type, + node.name, + "found" if replacement is not None else "missed", + ) + if replacement is None: + # No change. Process attributes. + for attr in node.attribute: + self.visit_attribute(attr) + return None + else: + self.modified = True + # We recursively visit the replacement nodes. + result = [] + for newnode in replacement: + n = self.visit_node(newnode) + if n is not None: + result.extend(n) + else: + result.append(newnode) + return result + + def visit_graph(self, graph: onnx.GraphProto) -> dict[str, ir.Value]: + self.enter_scope(graph) + self.process_graph(graph) + for input in graph.input: + self.process_graph_input(input) + for init in graph.initializer: + self.process_initializer(init) + for value_info in graph.value_info: + self.process_value_info(value_info) + updates = [] + nodes = graph.node + for i, node in enumerate(nodes): + replacement = self.visit_node(node) + if replacement is not None: + updates.append((i, replacement)) + for i, replacement in reversed(updates): + old_node_name = nodes[i].name + del nodes[i] + for newnode in reversed(replacement): + logger.debug( + "Replacement node %s for %s. 
Size %s", + newnode.name, + old_node_name, + newnode.ByteSize(), + ) + nodes.insert(i, newnode) + for output in graph.output: + self.process_graph_output(output) + return self.exit_scope(graph) + + +class FunctionCallsiteAnalysis(ProtoVisitor): + """Collects the callsites of each function.""" + + def __init__(self): + super().__init__() + self.functions: dict[ir.FunctionId, onnx.FunctionProto] = {} + self.function_calls: dict[ir.FunctionId, list[onnx.NodeProto]] = {} + + def visit_function(self, function: onnx.FunctionProto): + # Do not visit function via model.functions. + # Only visit function at callsites. + # The purpose of this analysis is to collect the callsites of each function. + pass + + def visit_node(self, node: onnx.NodeProto) -> None: + if is_local_function_node(node, self.functions): + function_id = ir.get_function_id_from_node(node) + self.function_calls.setdefault(function_id, []).append(node) + for subnode in self.functions[function_id].node: + self.visit_node(subnode) + + def visit_model(self, model: onnx.ModelProto) -> None: + for function in model.functions: + self.functions[ir.get_function_id(function)] = function + + super().visit_model(model) + + +class FunctionRenamer: + _POSTFIX_FORMAT = "{name}|{postfix}_{count}" + + def __init__(self, postfix="folded"): + self._function_key_to_instance_count = {} + self._postfix = postfix + + def rename(self, function: onnx.FunctionProto) -> None: + domain = function.domain + name = function.name + key = (domain, name) + self._function_key_to_instance_count.setdefault(key, 0) + function.name = self._POSTFIX_FORMAT.format( + name=name, + postfix=self._postfix, + count=self._function_key_to_instance_count[key], + ) + self._function_key_to_instance_count[key] += 1 + + +class FunctionCallsiteProtoTransformer(ProtoTransformer): + """Unlike other base visitors, this is a special visitor that visits functions at their callsite. + + This allows transforming and constructing specialized functions based on callsite context. + """ + + _functions: dict[ir.FunctionId, onnx.FunctionProto] + _function_callsites: dict[ir.FunctionId, list[onnx.NodeProto]] + _new_functions: list[onnx.FunctionProto] + _function_renamer: FunctionRenamer + + def _gather_function_metadata(self, model: onnx.ModelProto): + analysis = FunctionCallsiteAnalysis() + analysis.visit_model(model) + self._functions = analysis.functions + self._function_callsites = analysis.function_calls + self._new_functions = [] + self._function_renamer = FunctionRenamer() + + def process_function_outputs(self, function: onnx.FunctionProto) -> bool: + """Process function outputs. + + This method is called when a function is visited at its callsite. + + Returns: + True if the function outputs are modified. + """ + del function # Unused + return False + + def process_function_node_outputs( + self, + node: onnx.NodeProto, + function_scope: SubScope, + ) -> None: + """Fetch value infos of function output to re-bind them for function node output.""" + function = function_scope.owner + output_values = [function_scope.lookup(output) for output in function.output] + for actual_name, formal_value in zip(node.output, output_values): + if formal_value is None: + raise RuntimeError( + "Missing output %s in function-call to %s", + actual_name, + node.op_type, + ) + actual_value = self.lookup_or_create(actual_name) + actual_value.identity_merge_from(formal_value) + if logger.level <= logging.INFO: + logger.info( + "Binding outputs for function %s. 
%s => %s", + function.name, + actual_value, + node.output, + ) + + def lookup_ref_attribute(self, ref_attr_name: str) -> onnx.AttributeProto | None: + return self.scopes.lookup_ref_attribute(ref_attr_name) + + def bind_ref_attribute(self, ref_attr_name: str, attr: onnx.AttributeProto) -> None: + self.scopes.bind_ref_attribute(ref_attr_name, attr) + + def visit_model(self, model: onnx.ModelProto): + self._gather_function_metadata(model) + + self.process_model(model) + for opset in model.opset_import: + self.process_opset_import(opset) + self.visit_graph(model.graph) + + for new_function in self._new_functions: + model.functions.append(new_function) + + self.function_shape_env.save_to_model_proto(model) + + def visit_node(self, node: onnx.NodeProto) -> list[onnx.NodeProto] | None: + if is_local_function_node(node, self._functions): + function_id = ir.get_function_id_from_node(node) + if function_id not in self._functions: + # Do not recursively visit new functions. + return None + replacement, _ = self.process_function_node(node) + else: + replacement = self.process_node(node) + logger.debug( + "visit_node: %s::%s %s replacement %s", + node.domain, + node.op_type, + node.name, + "found" if replacement is not None else "missed", + ) + if replacement is None: + # No change. Process attributes. + for attr in node.attribute: + self.visit_attribute(attr) + return None + else: + self.modified = True + # We recursively visit the replacement nodes. + result = [] + for newnode in replacement: + n = self.visit_node(newnode) + if n is not None: + result.extend(n) + else: + result.append(newnode) + return result + + def process_function_node( + self, node: onnx.NodeProto + ) -> tuple[list[onnx.NodeProto] | None, onnx.FunctionProto | None]: + function_id = ir.get_function_id_from_node(node) + function = self._functions[function_id] + + is_unique_callsite = len(self._function_callsites[function_id]) == 1 + if not is_unique_callsite: + mutable_function = onnx.FunctionProto() + mutable_function.CopyFrom(function) + else: + mutable_function = function + + logger.info("Visit function %s node %s", function_id, node.name) + actual_input_value_infos = [self.lookup(input) for input in node.input] + # Handle omitted inputs, these are considered optional inputs of the function. + actual_input_value_infos.extend( + [None] * (len(function.input) - len(actual_input_value_infos)) + ) + ref_attributes = { + attr_proto.name: self.lookup_ref_attribute(attr_proto.ref_attr_name) + for attr_proto in node.attribute + if attr_proto.ref_attr_name + } + + self.enter_function_scope(mutable_function) + if logger.level <= logging.INFO: + printable_actual_input_value_infos = [ + str(x) for x in actual_input_value_infos + ] + logger.info( + "Actual input value infos: %s", + printable_actual_input_value_infos, + ) + logger.info("Enter function scope: %s", self.scopes.current_scope()) + + logger.debug("Binding inputs for function %s", function.name) + for actual_input_value_info, formal_input in zip( + actual_input_value_infos, function.input + ): + formal_info = ir.Value(formal_input) + if actual_input_value_info is not None: + formal_info.identity_merge_from(actual_input_value_info) + self.bind(formal_input, formal_info) + + for attr_proto in function.attribute_proto: + # Default value of function attributes. 
+ self.bind_ref_attribute(attr_proto.name, attr_proto) + + for attr_proto in node.attribute: + if attr_proto.ref_attr_name: + concrete_attribute = ref_attributes.get(attr_proto.name) + if concrete_attribute is None: + continue + self.bind_ref_attribute(attr_proto.name, concrete_attribute) + else: + self.bind_ref_attribute(attr_proto.name, attr_proto) + + # Visit inner function nodes. + node_updates: list[tuple[int, list[onnx.NodeProto]]] = [] + nodes = mutable_function.node + for i, inner_node in enumerate(nodes): + replacement = self.visit_node(inner_node) + if replacement is not None: + node_updates.append((i, replacement)) + for i, replacement in reversed(node_updates): + old_node_name = nodes[i].name + old_node_op_type = nodes[i].op_type + del nodes[i] + for newnode in reversed(replacement): + logger.debug( + "Replacement node inside function %s: %s for %s %s. Size %s", + node.name, + newnode.output, + old_node_name, + old_node_op_type, + newnode.ByteSize(), + ) + nodes.insert(i, newnode) + added_domains = set() + del mutable_function.opset_import[:] + for inner_node in nodes: + # Update opset_import if needed. + if inner_node.domain not in added_domains: + version = self.lookup_version(inner_node.domain) + mutable_function.opset_import.append( + onnx.OperatorSetIdProto(domain=inner_node.domain, version=version) + ) + added_domains.add(inner_node.domain) + + output_updates = self.process_function_outputs(mutable_function) + + is_new_function = not is_unique_callsite and (node_updates or output_updates) + if is_new_function: + self._new_functions.append(mutable_function) + self._function_renamer.rename(mutable_function) + node.op_type = mutable_function.name + + function_scope = self.exit_function_scope(mutable_function) + + self.process_function_node_outputs(node, function_scope) + + logger.info("Exit function scope: %s", function_scope) + logger.info("Exit function %s node %s", function_id, node.name) + + if is_new_function: + return [node], mutable_function + return None, None diff --git a/onnxscript/onnxrewriter/ir/visitor_test.py b/onnxscript/onnxrewriter/ir/visitor_test.py new file mode 100644 index 0000000000..8596b4f3bf --- /dev/null +++ b/onnxscript/onnxrewriter/ir/visitor_test.py @@ -0,0 +1,38 @@ +import unittest + +import onnx + +from onnxrewriter.ir import visitor + + +class FunctionCallsiteProtoTransformerTest(unittest.TestCase): + def test_function_optional_input_is_recorded_by_shape_env(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x) => (float[N] z) { + z = custom.function(x) + } + < + domain: "custom", + opset_import: ["" : 18] + > + function (x, optional_y, optional_z) => (return_val) + { + return_val = custom.custom_op (x, optional_y, optional_z) + } + """ + ) + + model_visitor = visitor.FunctionCallsiteProtoTransformer() + model_visitor.visit_model(model) + self.assertIsNotNone( + model_visitor.function_shape_env.lookup(model.functions[0], "optional_y") + ) + self.assertIsNotNone( + model_visitor.function_shape_env.lookup(model.functions[0], "optional_z") + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/onnxrewriter/optimizer/__init__.py b/onnxscript/onnxrewriter/optimizer/__init__.py new file mode 100644 index 0000000000..94653d7605 --- /dev/null +++ b/onnxscript/onnxrewriter/optimizer/__init__.py @@ -0,0 +1,110 @@ +import logging +from typing import Any + +import onnx + +from onnxrewriter import rewriter +from onnxrewriter.optimizer.constant_folding import fold_constants +from 
onnxrewriter.optimizer.copy_propagation import ( + do_copy_propagation, + do_sequence_simplification, +) +from onnxrewriter.optimizer.remove_unused import remove_unused_nodes +from onnxrewriter.optimizer.remove_unused_function import remove_unused_functions +from onnxrewriter.optimizer.simple_function_folding import ( + inline_functions_with_unused_outputs, + inline_simple_functions, +) +from onnxrewriter.rewriter import ( + broadcast_to_matmul, + cast_constant_of_shape, + gemm_to_matmul_add, + no_op, +) + +logger = logging.getLogger(__name__) + + +def optimize( + model: onnx.ModelProto, + num_iterations: int = 2, + *, + onnx_shape_inference: bool = True, + stop_if_no_change: bool = True, + external_data_folder: str = "", + **kwargs: Any, +) -> onnx.ModelProto: + """Optimize the model. Perform optimizations and clean-ups such as constant folding, dead code elimination, etc. + + Args: + model (onnx.ModelProto): The model to optimize. + num_iterations (int, optional): Number of iterations to perform. + onnx_shape_inference (bool, optional): Whether to perform onnx shape inference on the model. + Set this to False to turn off onnx shape inference, and rely on model carried shapes and types. + This is useful for models produced by PyTorch 2.2+ dynamo onnx exporter, where the model carries + the symbolic shapes recorded from dynamo tracing. + stop_if_no_change (bool, optional): Whether to stop if no change is detected. + external_data_folder (str, optional): The folder to store external data. + **kwargs: Additional keyword arguments. For BC purposes. + """ + if kwargs.pop("function_aware_folding", None) is not None: + logger.warning( + "'function_aware_folding' is deprecated. 'optimize' now supports both fully inlined models and models with functions. " + "To achieve the same behavior as 'function_aware_folding=True' before, set 'onnx_shape_inference=False'. " + "This would turn off incremental onnx shape inference and rely on model carried shapes and types. " + "See 'onnx_shape_inference' for more details." + ) + for _ in range(num_iterations): + if onnx_shape_inference: + model = onnx.shape_inference.infer_shapes( + model, check_type=True, strict_mode=True, data_prop=True + ) + + inline_simple_functions(model) + modified = fold_constants( + model, external_data_folder, onnx_shape_inference=onnx_shape_inference + ) + + remove_unused_nodes(model) + inline_simple_functions(model) + remove_unused_functions(model) + inline_functions_with_unused_outputs(model) + # NOTE: This is general rewrite rules + model = rewriter.rewrite( + model, + pattern_rewrite_rules=[ + *no_op.rules.rules, # TODO: merge this rule into constant folding? 
+ *broadcast_to_matmul.rules.rules, + gemm_to_matmul_add.rule, + *cast_constant_of_shape.rules.rules, + ], + ) + if stop_if_no_change and not modified: + logger.debug("Stopping after %d iterations.", _) + break + + for node in model.graph.node: + logger.debug("Node %s::%s name %s.", node.domain, node.op_type, node.name) + + for function in model.functions: + for node in function.node: + logger.debug( + "Function %s::%s node %s::%s name %s.", + function.domain, + function.name, + node.domain, + node.op_type, + node.name, + ) + + # do_sequence_simplification(model) + return model + + +__all__ = [ + "fold_constants", + "remove_unused_nodes", + "optimize", + "do_copy_propagation", + "do_sequence_simplification", +] diff --git a/onnxscript/onnxrewriter/optimizer/constant_folding.py b/onnxscript/onnxrewriter/optimizer/constant_folding.py new file mode 100644 index 0000000000..ce26c5d722 --- /dev/null +++ b/onnxscript/onnxrewriter/optimizer/constant_folding.py @@ -0,0 +1,292 @@ +from __future__ import annotations + +import logging +from typing import Any, Sequence + +import numpy as np +import onnx +import onnx.reference.ops + +from onnxrewriter import ir +from onnxrewriter.ir import visitor +from onnxrewriter.optimizer import evaluator +from onnxrewriter.utils.utils import ( + is_control_flow_op, + is_onnx_domain, +) + +logger = logging.getLogger(__name__) + +_DEFAULT_CONSTANT_FOLD_SIZE_LIMIT = 1024 * 1024 + +# Ops excluded from constant-propagation: +# * Random ops, which are not deterministic (checked below) +# * Control flow ops (checked by presence of graph-attribute) + +non_deterministic_ops = frozenset( + { + "RandomUniform", + "RandomNormal", + "RandomUniformLike", + "RandomNormalLike", + "Multinomial", + } +) + +onnx_domain = frozenset({"", "onnx.ai"}) + + +def is_non_deterministic_op(node: onnx.NodeProto) -> bool: + return node.op_type in non_deterministic_ops and is_onnx_domain(node.domain) + + +def is_constant_op(node: onnx.NodeProto) -> bool: + return node.op_type in {"Constant", "ConstantOfShape"} and is_onnx_domain( + node.domain + ) + + +class ConstantFolder(visitor.FunctionCallsiteProtoTransformer): + def __init__( + self, + registry: evaluator.PartialEvaluatorRegistry, + external_data_folder: str, + *, + do_shape_inference: bool, + ) -> None: + self.registry = registry + # TODO: make evaluator a parameter + self.evaluate = evaluator.reference_evaluator.evaluate + self._do_shape_inference = do_shape_inference + self._init() + super().__init__(external_data_folder, do_shape_inference=do_shape_inference) + + def _init(self) -> None: + self.counts = {} + self.sizes = {} + + def add_count(self, op: str, size: int = 1): + self.counts[op] = self.counts.get(op, 0) + 1 + self.sizes[op] = self.sizes.get(op, 0) + size + + def foldable_value(self, name: str, value): + """Checks if a runtime-constant can and should be folded into the graph. + + We fold constants only if they are tensors (not lists of tensors, for example) + and have size below desired limit. + """ + if value is ir.NotConstant: + return None + + if not isinstance(value, np.ndarray): + # ONNX does not have a way to represent non-tensor constants, eg. a sequence. + # So, a constant-value of type sequence is not folded, but it can be used + # to optimize subsequent operations when possible. 
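+            # Returning None (here, and below for oversized tensors) means the caller
+            # will not materialize a Constant node for this value; new_constant still
+            # records the value on the IR value info, so later partial evaluation can use it.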
+ logger.warning( + "Skip storing constant folded value %s due to unsupported type %s.", + name, + type(value), + ) + return None + + if value.nbytes > _DEFAULT_CONSTANT_FOLD_SIZE_LIMIT: + logger.warning( + "Skip storing constant folded nvalue %s due to large size %s.", + name, + value.nbytes, + ) + return None + + return onnx.numpy_helper.from_array(value, name) + + def new_constant(self, name, value): + if isinstance(value, (int, float, np.ScalarType)): + value = np.array(value) + + info = self.lookup_or_create(name) + info.value = value + + tensor = self.foldable_value(name, value) + if tensor is None: + return None + + logger.debug( + "New constant for value %s dtype: %s shape: %s", + name, + value.dtype, + value.shape, + ) + info.type = onnx.helper.make_tensor_type_proto( + onnx.helper.np_dtype_to_tensor_dtype(value.dtype), value.shape + ) + node = onnx.helper.make_node( + "Constant", inputs=[], outputs=[name], value=tensor + ) + return [node] + + def convert_attributes( + self, attributes: Sequence[onnx.AttributeProto] + ) -> dict[str, Any]: + if self.scopes.current_scope().current_function_scope(): + # Need to resolve ref_attr_name if inside a function. + attr_dict = {} + for attribute in attributes: + concrete_attribute = ( + self.lookup_ref_attribute(attribute.ref_attr_name) + if attribute.ref_attr_name + else attribute + ) + if concrete_attribute is None: + continue + attr_dict[attribute.name] = onnx.helper.get_attribute_value( + concrete_attribute + ) + return attr_dict + return {attr.name: onnx.helper.get_attribute_value(attr) for attr in attributes} + + def replace_copy(self, node: onnx.NodeProto) -> None: + for i in range(len(node.input)): + input = self.get_input(node, i) + if input is not None and input.is_copy(): + old_value = self.lookup_or_create(input.name) + assert isinstance(input.symbolic_value, str) + new_value = self.lookup_or_create(input.symbolic_value) + # Merge meta info. It is important to do if the new value + # is created by evaluator, and thus carries zero meta info. + # Since this is a copy, the meta info should be the same. + new_value.identity_merge_from(old_value) + node.input[i] = input.symbolic_value + + def process_function_outputs(self, function: onnx.FunctionProto) -> bool: + # Resolve copy for function subgraph output. + # Avoid copy of function subgraph input, because it is illegal for a direct edge + # from function input to function output. 
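+        # For example, if output "z" is just Identity(x) of a function input "x",
+        # rewriting the output entry to "x" would create exactly such an edge, so
+        # that copy is intentionally left in place (hence the prohibited set below).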
+ prohibited_value_set = set(function.input) + updated = False + for i, output_name in enumerate(function.output): + output = self.lookup(output_name) + if ( + output is not None + and output.is_copy() + and output.symbolic_value not in prohibited_value_set + ): + old_value = self.lookup_or_create(output.name) + assert isinstance(output.symbolic_value, str) + new_value = self.lookup_or_create(output.symbolic_value) + new_value.identity_merge_from(old_value) + function.output[i] = output.symbolic_value + updated = True + return updated + + def process_node(self, node: onnx.NodeProto) -> Sequence[onnx.NodeProto] | None: + self.replace_copy(node) + + super().process_node(node) + + inputs = [self.lookup(x) for x in node.input] + attrs = self.convert_attributes(node.attribute) + + domain = node.domain + op = node.op_type + version = self.lookup_version(domain) + + # if any(x is Undefined for x in inputs): + # return None + # Above check ensures that none of the optimizations below need to handle + # undefined inputs + + op_optimizers = self.registry.lookup_evaluators(domain, op, version) + for optimizer in op_optimizers: + assert optimizer + output = optimizer(self, node) + if output is None: + continue + if isinstance(output, list): + return output + else: + # Currently handles single output only + self.add_count(node.op_type, output.size) + return self.new_constant(node.output[0], output) + + if is_control_flow_op(node) or is_non_deterministic_op(node): + return None + + input_values = [x.value if x is not None else None for x in inputs] + if any(x is ir.NotConstant for x in input_values): + return None + + outputs = self.evaluate(domain, op, version, *input_values, **attrs) + # TODO: what if evaluated value is None? + if outputs is None: + return None + if len(node.output) == 1 and not isinstance(outputs, (tuple, list)): + replacement = self.new_constant(node.output[0], outputs) + if is_constant_op(node): + return None + self.add_count(op, outputs.size) + return replacement + else: + logger.warning( + "Skipping constant folding for op %s with multiple outputs.", op + ) + return None + + def process_function_node( + self, node: onnx.NodeProto + ) -> tuple[list[onnx.NodeProto] | None, onnx.FunctionProto | None]: + self.replace_copy(node) + + _, new_function = super().process_function_node(node) + + # Replace function node with Constant if all outputs are constants + ir_values = [self.lookup(output_name) for output_name in node.output] + tensors = [ + self.foldable_value( + output_name, ir_value.value if ir_value is not None else None + ) + for output_name, ir_value in zip(node.output, ir_values) + ] + if all(tensor is not None for tensor in tensors): + replacements = [] + for output_name, tensor in zip(node.output, tensors): + newnode = onnx.helper.make_node( + "Constant", inputs=[], outputs=[output_name], value=tensor + ) + replacements.append(newnode) + logger.debug( + "Function node replacements: node %s %s (%s/%s)", + node.name, + [replacement.output for replacement in replacements], + len(replacements), + len(node.output), + ) + return replacements, new_function + return None, new_function + + def visit_model(self, model: onnx.ModelProto) -> None: + self._init() + + super().visit_model(model) + + +def fold_constants( + model: onnx.ModelProto, + external_data_folder: str = "", + *, + onnx_shape_inference: bool = False, +) -> bool: + """Returns true iff the model was modified.""" + folder = ConstantFolder( + evaluator.registry, + external_data_folder, + 
do_shape_inference=onnx_shape_inference, + ) + folder.visit_model(model) + for op in folder.counts: + logger.info( + "Constant-folded '%s' %s times, with %s size.", + op, + folder.counts[op], + folder.sizes[op], + ) + return folder.modified diff --git a/onnxscript/onnxrewriter/optimizer/constant_folding_test.py b/onnxscript/onnxrewriter/optimizer/constant_folding_test.py new file mode 100644 index 0000000000..031ebd72b8 --- /dev/null +++ b/onnxscript/onnxrewriter/optimizer/constant_folding_test.py @@ -0,0 +1,452 @@ +import unittest + +import onnx +import pytest + +from onnxrewriter import optimizer + + +class FoldConstantsTest(unittest.TestCase): + def test_fold_add(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x) => (float[N] z) { + two = Constant () + four = Add(two, two) + z = Mul(x, four) + } + """ + ) + optimized = optimizer.optimize(model, num_iterations=1) + self.assertEqual(len(optimized.graph.node), 2) + self.assertEqual(optimized.graph.node[0].output[0], "four") + + def test_fold_cast_like(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x) => (float[N] z) { + two = Constant () + two_float = CastLike(two, x) + four = Add(two_float, two_float) + z = Mul(x, four) + } + """ + ) + optimized = optimizer.optimize(model, num_iterations=1) + self.assertEqual(len(optimized.graph.node), 2) + self.assertEqual(optimized.graph.node[0].output[0], "four") + + def test_fold_shape(self): + model = onnx.parser.parse_model( + """ + + agraph (float[16, 16] x) => (float[16, 16] z) { + shape = Shape(x) + rank = Size(shape) + two_float = CastLike(rank, x) + four = Add(two_float, two_float) + z = Mul(x, four) + } + """ + ) + optimized = optimizer.optimize(model, num_iterations=1) + self.assertEqual(len(optimized.graph.node), 2) + self.assertEqual(optimized.graph.node[0].output[0], "four") + + def test_fold_shape_slice(self): + model = onnx.parser.parse_model( + """ + + agraph (float[M, N, 16, 16] x) => (float[M, N, 16, 16] z) { + shape = Shape (x) + two = Size(shape) + two_float = CastLike(two, x) + four = Add(two_float, two_float) + z = Mul(x, four) + } + """ + ) + optimized = optimizer.optimize(model, num_iterations=1) + self.assertEqual(len(optimized.graph.node), 2) + self.assertEqual(optimized.graph.node[0].output[0], "four") + + def test_fold_if_cond(self): + model = onnx.parser.parse_model( + """ + + agraph (float[16, 16] x) => (float[16, 16] z) { + shape = Shape(x) + rank = Size(shape) + zero = Constant () + zero_cast = CastLike (zero, rank) + is_scalar = Equal(zero_cast, rank) + z = If (is_scalar) < + then_branch = then_graph () => (then_z) { then_z = Add (x, x) }, + else_branch = else_graph () => (else_z) { else_z = Mul (x, x) } + > + } + """ + ) + optimized = optimizer.optimize(model, num_iterations=1) + self.assertEqual(len(optimized.graph.node), 1) + self.assertEqual(optimized.graph.node[0].output[0], "z") + self.assertEqual(optimized.graph.node[0].op_type, "Mul") + + def test_fold_inside_if_branch(self): + model = onnx.parser.parse_model( + """ + + agraph (float[16, 16] x, bool cond) => (float[16, 16] z) { + two = Constant () + z = If (cond) < + then_branch = then_graph () => (then_z) { + three = Constant () + temp = Add (two, three) + then_z = Mul (temp, x) + }, + else_branch = else_graph () => (else_z) { + four = Constant () + temp = Add (two, four) + else_z = Mul (temp, x) + } + > + } + """ + ) + optimized = optimizer.optimize(model, num_iterations=1) + self.assertEqual(len(optimized.graph.node), 1) + then_graph = 
onnx.helper.get_node_attr_value( + optimized.graph.node[0], "then_branch" + ) + self.assertEqual(len(then_graph.node), 2) + else_graph = onnx.helper.get_node_attr_value( + optimized.graph.node[0], "else_branch" + ) + self.assertEqual(len(else_graph.node), 2) + + def test_fold_if_propagate(self): + model = onnx.parser.parse_model( + """ + + agraph (float[16, 16] x) => (float[16, 16] z) { + shape = Shape(x) + rank = Size(shape) + zero = Constant () + two = Constant () + zero_cast = CastLike (zero, rank) + is_scalar = Equal(zero_cast, rank) + m = If (is_scalar) < + then_branch = then_graph () => (then_z) { then_z = Add (x, x) }, + else_branch = else_graph () => (else_z) { else_z = Mul (two, two) } + > + m_square = Mul (m, m) + z = Mul (x, m_square) + } + """ + ) + optimized = optimizer.optimize(model, num_iterations=1) + print(onnx.printer.to_text(optimized)) + self.assertEqual(len(optimized.graph.node), 2) + self.assertEqual(optimized.graph.node[0].output[0], "m_square") + self.assertEqual(optimized.graph.node[0].op_type, "Constant") + + def test_fold_redundant_cast(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x) => (float[N] z) { + two = Constant () + x_cast = CastLike(x, two) + z = Mul(x_cast, two) + } + """ + ) + optimized = optimizer.optimize(model, num_iterations=1) + self.assertEqual(len(optimized.graph.node), 2) + + def test_fold_redundant_cast2(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x) => (float[N] z) { + two = Constant () + z = CastLike(x, two) + } + """ + ) + optimized = optimizer.optimize(model, num_iterations=1) + self.assertEqual(len(optimized.graph.node), 1) + self.assertEqual(optimized.graph.node[0].op_type, "Identity") + self.assertEqual(optimized.graph.node[0].output[0], "z") + self.assertEqual(optimized.graph.node[0].input[0], "x") + + @pytest.mark.skip(reason="Feature removed to catch errors early") + def test_fold_undefined_vars(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x) => (float[N] z) { + four = Add(two, two) + y = Shape(t1) + w = CastLike(x, t2) + w2 = CastLike(t3, t4) + w3 = Size(t5) + z = Sum (four, y, w, w2, w3) + } + """ + ) + # No optimizations expected. Just make sure it doesn't crash. 
+ optimized = optimizer.optimize( + model, num_iterations=1, onnx_shape_inference=False + ) + self.assertEqual(len(optimized.graph.node), 6) + + def test_shape_inference(self): + model = onnx.parser.parse_model( + """ + + agraph (int64[64] x) => (int64[N] z) { + one = Constant () + cond = Equal(one, one) + temp = If (cond) < + then_branch = then_graph () => (then_z) { + shape1 = Constant () + then_z = Reshape(x, shape1) + }, + else_branch = else_graph () => (else_z) { + shape2 = Constant () + else_z = Reshape(x, shape2) + }> + shape = Shape(temp) # shape = [8, 8] or [64], but [8, 8] after constant propagation + rank = Size(shape) # rank = 2 or 1, but 2 after constant propagation + C = Add (rank, rank) + z = Mul(x, C) + } + """ + ) + optimized = optimizer.optimize(model, num_iterations=1) + print(onnx.printer.to_text(optimized)) + self.assertEqual(len(optimized.graph.node), 2) + self.assertEqual(optimized.graph.node[0].output[0], "C") + + def test_static_split_to_sequence_with_scalar_split_and_squence_at_is_folded_as_split( + self, + ): + model = onnx.parser.parse_model( + """ +< + ir_version: 8, + opset_import: ["" : 18] +> +func (float[1,512] x) => ( return_val) { + int64_128 = Constant () + splits = SplitToSequence (x, int64_128) + int64_0 = Constant () + split_0 = SequenceAt (splits, int64_0) + int64_1 = Constant () + split_1 = SequenceAt (splits, int64_1) + int64_2 = Constant () + split_2 = SequenceAt (splits, int64_2) + int64_3 = Constant () + split_3 = SequenceAt (splits, int64_3) + return_val = Concat (split_0, split_1, split_2, split_3) +} + """ + ) + + # TODO: There is an unrelated limitation that `symbolic_value` is not + # utilized when the value is only referenced by graph output. + # E.g., the following test model will not have this optimization + # applied. 
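+        # (The model text below is kept only as documentation of that case; it is a
+        # bare string literal and is never passed to the parser or the optimizer.)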
+ """ +< + ir_version: 8, + opset_import: ["" : 18] +> +func (float[1,512] x) => ( split_0, split_1, split_2, split_3) { + int64_128 = Constant () + splits = SplitToSequence (x, int64_128) + int64_0 = Constant () + split_0 = SequenceAt (splits, int64_0) + int64_1 = Constant () + split_1 = SequenceAt (splits, int64_1) + int64_2 = Constant () + split_2 = SequenceAt (splits, int64_2) + int64_3 = Constant () + split_3 = SequenceAt (splits, int64_3) +} + """ + optimized = optimizer.optimize(model, num_iterations=1) + self.assertEqual(len(optimized.graph.node), 2) + self.assertEqual(len(optimized.graph.node[-2].output), 4) + self.assertEqual(optimized.graph.node[-2].op_type, "Split") + + def test_static_split_to_sequence_with_list_split_and_squence_at_is_folded_as_split( + self, + ): + model = onnx.parser.parse_model( + """ +< + ir_version: 8, + opset_import: ["" : 18] +> +func (float[1,512] x) => ( return_val) { + const = Constant () + splits = SplitToSequence (x, const) + int64_0 = Constant () + split_0 = SequenceAt (splits, int64_0) + int64_1 = Constant () + split_1 = SequenceAt (splits, int64_1) + int64_2 = Constant () + split_2 = SequenceAt (splits, int64_2) + return_val = Concat (split_0, split_1, split_2) +} + """ + ) + optimized = optimizer.optimize(model, num_iterations=1) + self.assertEqual(len(optimized.graph.node), 3) + self.assertEqual(len(optimized.graph.node[-2].output), 3) + self.assertEqual(optimized.graph.node[-2].op_type, "Split") + + def test_static_split_to_sequence_with_list_split_no_keepdims_and_squence_at_is_folded_as_split_with_squeeze( + self, + ): + model = onnx.parser.parse_model( + """ +< + ir_version: 8, + opset_import: ["" : 18] +> +func (float[1,3] x) => ( return_val) { + const = Constant () + splits = SplitToSequence (x, const) + int64_0 = Constant () + split_0 = SequenceAt (splits, int64_0) + int64_1 = Constant () + split_1 = SequenceAt (splits, int64_1) + int64_2 = Constant () + split_2 = SequenceAt (splits, int64_2) + return_val = Concat (split_0, split_1, split_2) +} + """ + ) + optimized = optimizer.optimize(model, num_iterations=1) + self.assertEqual(len(optimized.graph.node), 7) + self.assertEqual(len(optimized.graph.node[1].output), 3) + self.assertEqual(optimized.graph.node[1].op_type, "Split") + self.assertEqual( + len([n for n in optimized.graph.node if n.op_type == "Squeeze"]), 3 + ) + + def test_static_split_to_sequence_with_uneven_split(self): + model = onnx.parser.parse_model( + """ +< + ir_version: 8, + opset_import: ["pkg.onnxscript.torch_lib" : 1, "" : 18, "pkg.onnxscript.torch_lib.common" : 1], + producer_name: "pytorch", + producer_version: "2.2.0" +> +main_graph (float[3,5] l_tensor_x_) => (float[3,5] return_val) + < _val_2, float[3,5] l_tensor_x_, float[2,5] getitem, float[1,5] getitem_1> +{ + _val_1 = Constant () + _val_2 = pkg.onnxscript.torch_lib.aten_split (l_tensor_x_, _val_1) + _val_3 = Constant () + getitem = pkg.onnxscript.torch_lib.aten_getitem (_val_2, _val_3) + _val_5 = Constant () + getitem_1 = pkg.onnxscript.torch_lib.aten_getitem (_val_2, _val_5) + return_val = Concat (getitem_1, getitem) +} +< + domain: "pkg.onnxscript.torch_lib", + opset_import: ["" : 18] +> +aten_split (self, split_size) => (return_val) +{ + return_val = SplitToSequence (self, split_size) +} +< + domain: "pkg.onnxscript.torch_lib", + opset_import: ["" : 18] +> +aten_getitem (self, i) => (return_val) +{ + return_val = SequenceAt (self, i) +} +< + domain: "pkg.onnxscript.torch_lib.common", + opset_import: ["" : 18] +> +Rank (input) => (return_val) +{ + tmp = 
Shape (input) + return_val = Size (tmp) +} +< + domain: "pkg.onnxscript.torch_lib.common", + opset_import: ["" : 18] +> +IsScalar (input) => (return_val) +{ + tmp = Shape (input) + tmp_0 = Size (tmp) + tmp_1 = Constant () + return_val = Equal (tmp_0, tmp_1) +} + """ + ) + optimized = optimizer.optimize(model, onnx_shape_inference=False) + + print(onnx.printer.to_text(optimized)) + self.assertEqual(len(optimized.graph.node), 2) + self.assertEqual(len(optimized.graph.node[0].output), 2) + self.assertEqual(optimized.graph.node[0].op_type, "Split") + + def test_split_to_sequence_and_concat_from_sequence_with_new_axis_0( + self, + ): + model = onnx.parser.parse_model( + """ +< + ir_version: 8, + opset_import: ["" : 18] +> +func (float[1,3] x) => ( return_val) { + const = Constant () + splits = SplitToSequence (x, const) + return_val = ConcatFromSequence (splits) +} + """ + ) + optimized = optimizer.optimize(model, num_iterations=1) + self.assertEqual(len(optimized.graph.node), 3) + self.assertEqual(optimized.graph.node[2].op_type, "Concat") + onnx.checker.check_model(optimized) + + def test_split_to_sequence_and_concat_from_sequence_with_new_axis_1( + self, + ): + model = onnx.parser.parse_model( + """ +< + ir_version: 8, + opset_import: ["" : 18] +> +func (float[1,3] x) => ( return_val) { + const = Constant () + splits = SplitToSequence (x, const) + return_val = ConcatFromSequence (splits) +} + """ + ) + optimized = optimizer.optimize(model, num_iterations=1) + self.assertEqual(len(optimized.graph.node), 7) + self.assertEqual(optimized.graph.node[6].op_type, "Concat") + onnx.checker.check_model(optimized) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/onnxrewriter/optimizer/copy_propagation.py b/onnxscript/onnxrewriter/optimizer/copy_propagation.py new file mode 100644 index 0000000000..6b73303ade --- /dev/null +++ b/onnxscript/onnxrewriter/optimizer/copy_propagation.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +from typing import Any + +import onnx + +import onnxrewriter.optimizer.remove_unused +from onnxrewriter.ir import visitor +from onnxrewriter.utils.utils import is_onnx_op + + +class CopyPropagator(visitor.ProtoVisitor): + def __init__(self): + super().__init__() + + def visit_node(self, node: onnx.NodeProto) -> None: + super().visit_node(node) + for i in range(len(node.input)): + input = self.get_input(node, i) + if input is not None and input.is_copy(): + node.input[i] = input.symbolic_value # type: ignore[assignment] + + if is_onnx_op(node, "Identity"): + input = self.get_input(node, 0) + output = self.get_output(node, 0) + if input is not None and output is not None: + output.symbolic_value = input.name + + +# TODO: "Z = Identity(x)" where Z is a graph-output cannot be handled by this optimization, +# and requires some extension. (Eg., we could rename graph-output to be Z or we can try to +# rename x to be Z.) 
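+# Illustration of the limitation above: in a graph such as
+#
+#     t = Relu(x)
+#     z = Identity(t)   # "z" is a graph output
+#
+# nothing inside the graph consumes "z", so there is no node input to rewrite, and
+# graph outputs themselves are not renamed here; the Identity therefore remains.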
+ + +def get_node_attr_value(node: onnx.NodeProto, attr_name: str, default: Any) -> Any: + matching = [x for x in node.attribute if x.name == attr_name] + if len(matching) > 1: + raise ValueError(f"Node has multiple attributes with name {attr_name}") + if len(matching) < 1: + return default + return onnx.helper.get_attribute_value(matching[0]) + + +class SymbolicEvaluator(CopyPropagator): + def __init__(self): + super().__init__() + + def visit_node(self, node: onnx.NodeProto) -> None: + super().visit_node(node) + + if is_onnx_op(node, "SequenceConstruct"): + output = self.get_output(node, 0) + if output is not None: + output.symbolic_value = list(node.input) + + if is_onnx_op(node, "ConcatFromSequence"): + input = self.get_input(node, 0) + new_axis = get_node_attr_value(node, "new_axis", 0) + if ( + input is not None + and isinstance(input.symbolic_value, list) + and new_axis == 0 + ): + node.op_type = "Concat" + node.input[:] = input.symbolic_value + for i in range(len(node.attribute)): + if node.attribute[i].name == "new_axis": + del node.attribute[i] + break + + # TODO: handle SequenceEmpty, SequenceAt, etc. + + +def do_copy_propagation(model: onnx.ModelProto, *, remove_unused: bool = True) -> None: + transformer = CopyPropagator() + transformer.visit_model(model) + if remove_unused: + onnxrewriter.optimizer.remove_unused_nodes(model) + + +def do_sequence_simplification( + model: onnx.ModelProto, *, remove_unused: bool = True +) -> None: + transformer = SymbolicEvaluator() + transformer.visit_model(model) + if remove_unused: + onnxrewriter.optimizer.remove_unused_nodes(model) diff --git a/onnxscript/onnxrewriter/optimizer/copy_propagation_test.py b/onnxscript/onnxrewriter/optimizer/copy_propagation_test.py new file mode 100644 index 0000000000..4bf2a93a94 --- /dev/null +++ b/onnxscript/onnxrewriter/optimizer/copy_propagation_test.py @@ -0,0 +1,49 @@ +import unittest + +import onnx + +from onnxrewriter import optimizer + + +class RemoveUnusedTest(unittest.TestCase): + def test_simple_identity_removal(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x) => (float[N] z) { + t = Identity(x) + t2 = Identity(t) + z = Identity(t2) + } + """ + ) + optimizer.do_copy_propagation(model) + self.assertEqual(len(model.graph.node), 1) + + def test_subgraph_identity_removal(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x, bool cond) => (float[N] z) { + t = Identity(x) + t2 = Identity(t) + t3 = If (cond) < + then_branch = then_graph() => (t4) { + t5 = Identity(t2) + t4 = Identity(t5) + }, + else_branch = else__graph() => (t6) { + t7 = Identity(t) + t6 = Identity(t7) + } + > + z = Identity(t3) + } + """ + ) + optimizer.do_copy_propagation(model) + self.assertEqual(len(model.graph.node), 2) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/onnxrewriter/optimizer/evaluator.py b/onnxscript/onnxrewriter/optimizer/evaluator.py new file mode 100644 index 0000000000..db4f014dd4 --- /dev/null +++ b/onnxscript/onnxrewriter/optimizer/evaluator.py @@ -0,0 +1,442 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# ------------------------------------------------------------------------- + +from __future__ import annotations + +import dataclasses +import logging +import math +from typing import Any, Callable, Protocol, Sequence, Union + +import numpy as np +import onnx +import onnx.reference.ops + +from onnxrewriter import ir +from onnxrewriter.utils.utils import ( + get_node_attr_value, +) + +logger = logging.getLogger(__name__) + +# "Standard" evaluators are used to perform constant-folding. +# The API below works only for non-control-flow ops (ops without any graph-attributes). +# This currently used ONNX's reference implementation. But we could also +# use ORT's implementation if we want to. + + +class ReferenceEvaluator: + def get_evaluator(self, domain: str, op: str, version: int) -> callable | None: + try: + op_impl_class = onnx.reference.ops.load_op(domain, op, version) + return op_impl_class.eval # noqa: TRY300 + except Exception: # noqa: BLE001 + return None + + def evaluate(self, domain: str, op: str, version: int, *args, **kwargs) -> Any: + logger.debug("Evaluating %s::%s", domain, op) + evaluator = self.get_evaluator(domain, op, version) + if evaluator is None: + return None + return evaluator(*args, **kwargs) + + +reference_evaluator = ReferenceEvaluator() + +# The "partial evaluators" below are non-standard evaluators. They are used to perform +# partial evaluation and/or static program analysis (abstract interpretation). + + +class IRContext(Protocol): + """A class that represents the context for partial evaluation. + + This is a placeholder, subject to simplification when a proper IR is defined. + """ + + def get_input(self, node: onnx.NodeProto, index: int) -> ir.Value | None: ... + + def get_output(self, node: onnx.NodeProto, index: int) -> ir.Value | None: ... + + def input_const_value( + self, node: onnx.NodeProto, index: int + ) -> ir.ConcreteValue: ... + + def input_shape( + self, node: onnx.NodeProto, index: int + ) -> onnx.TensorShapeProto | None: ... + + def input_type(self, node: onnx.NodeProto, index: int) -> onnx.TypeProto | None: ... + + def input_element_type(self, node: onnx.NodeProto, index: int) -> int | None: ... + + def lookup_version(self, domain: str) -> int: ... + + def convert_attributes(self, attributes: Sequence[onnx.AttributeProto]) -> dict: ... + + def new_constant( + self, name: str, value: Any + ) -> Sequence[onnx.NodeProto] | None: ... + + +# A partial-evaluator function takes an IRContext and a node, and returns a list of +# replacement nodes or None (if no replacement is needed). We return None instead +# of [input node] so the caller is aware that the node is not replaced. If the node +# is replaced, the caller will recursively visit the replacement nodes to process them. + +PartialEvaluatorFunction = Union[ + Callable[[IRContext, onnx.NodeProto], Sequence[onnx.NodeProto]], None +] + + +@dataclasses.dataclass +class PartialEvaluator: + """A class that represents a partial-evaluator for a particular op. + + It is applicable for a specific version range (min_version, max_version) of the op. + The min_version and max_version can be None, indicating that there is no version + constraint in that direction. 
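+    For example, PartialEvaluator(13, None, fn) is valid for opset version 13 and
+    every later version, while PartialEvaluator(None, None, fn) applies to all versions.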
+ """ + + min_version: int | None + max_version: int | None + function: PartialEvaluatorFunction + + def valid_for(self, version: int) -> bool: + """Returns True if this evaluator is applicable for the given version.""" + return (self.min_version is None or version >= self.min_version) and ( + self.max_version is None or version <= self.max_version + ) + + +class PartialEvaluatorRegistry: + """A class that maintains a registry of evaluators for ops.""" + + def __init__(self): + self.op_evaluators: dict[tuple[str, str], list[PartialEvaluator]] = {} + + def lookup_evaluators(self, domain: str, opname: str, version: int): + evaluator_list = self.op_evaluators.get((domain, opname), []) + return [ + evaluator.function + for evaluator in evaluator_list + if evaluator.valid_for(version) + ] + + def register(self, opname: str, domain: str = "", version=None): + if (domain, opname) not in self.op_evaluators: + evaluator_list = [] + self.op_evaluators[(domain, opname)] = evaluator_list + else: + evaluator_list = self.op_evaluators[(domain, opname)] + if version is None: + min_version = None + max_version = None + elif isinstance(version, int): + min_version = version + max_version = version + elif isinstance(version, tuple): + min_version, max_version = version + + def decorator(function: PartialEvaluatorFunction) -> PartialEvaluatorFunction: + evaluator_list.append(PartialEvaluator(min_version, max_version, function)) + return function + + return decorator + + +registry: PartialEvaluatorRegistry = PartialEvaluatorRegistry() + +register = registry.register + + +def get_bool_value(val) -> bool | None: + if isinstance(val, bool): + return val + if isinstance(val, np.bool_): + return bool(val) + if isinstance(val, np.ndarray) and val.size == 1 and val.dtype == bool: + return val.item(0) + return None + + +def get_size_info(type: onnx.TypeProto) -> np.ndarray | None: + if type.HasField("tensor_type") and type.tensor_type.HasField("shape"): + if all(d.HasField("dim_value") for d in type.tensor_type.shape.dim): + size = 1 + for d in type.tensor_type.shape.dim: + size *= d.dim_value + return np.array(size, dtype=np.int64) + return None + + +def get_dim_info(type: onnx.TypeProto, dim: int) -> int | None: + if type.HasField("tensor_type") and type.tensor_type.HasField("shape"): + rank = len(type.tensor_type.shape.dim) + dim = dim if dim >= 0 else dim + rank + if dim < 0 or dim >= rank: + return None + if type.tensor_type.shape.dim[dim].HasField("dim_value"): + return type.tensor_type.shape.dim[dim].dim_value + return None + + +@register("Cast") +def cast(context: IRContext, node: onnx.NodeProto) -> Sequence[onnx.NodeProto] | None: + if context.input_shape(node, 0) is not None: + output_value = context.get_output(node, 0) + output_value.type = onnx.TypeProto() + output_value.type.CopyFrom(context.input_type(node, 0)) + output_value.type.tensor_type.elem_type = node.attribute[0].i + return None + + +@register("CastLike") +def cast_like(context: IRContext, node: onnx.NodeProto): + source_element_type = context.input_element_type(node, 0) + target_element_type = context.input_element_type(node, 1) + + if target_element_type is None: + return None + if source_element_type == target_element_type: + node.op_type = "Identity" + del node.input[1] + return [node] + + node.op_type = "Cast" + del node.input[1] + del node.attribute[:] + node.attribute.append(onnx.helper.make_attribute("to", target_element_type)) + return [node] + + +@register("Shape") +def shape(context: IRContext, node: onnx.NodeProto): + shape = 
context.input_shape(node, 0) + if shape is None: + return None + start = get_node_attr_value(node, "start", 0) + end = get_node_attr_value(node, "end", None) + shape_slice = shape.dim[start:end] + if all(d.HasField("dim_value") for d in shape_slice): + return np.array([d.dim_value for d in shape_slice], dtype=np.int64) + return None + + +@register("Size") +def size(context: IRContext, node: onnx.NodeProto): + type = context.input_type(node, 0) + size = get_size_info(type) if type is not None else None + return size + + +@register("If") +def if_op(context: IRContext, node: onnx.NodeProto): + cond = context.input_const_value(node, 0) + if cond is ir.NotConstant: + # Visitor will recursively visit subgraphs to constant-fold them. + return None + cond = get_bool_value(cond) + if cond is not None: + # cond is a constant-value: inline the branch + branch = "then_branch" if cond else "else_branch" + graph = onnx.helper.get_node_attr_value(node, branch) + + formal_outs = list(graph.output) + actual_outs = node.output + renamings = { + formal.name: actual + for formal, actual in zip(formal_outs, actual_outs) + if actual != "" + } + # TODO: Extend renaming to intermediate values. + + def rename(name): + return renamings.get(name, name) + + for sub_node in graph.node: + # TODO: handle renaming inside subgraphs in nodes + sub_node.input[:] = [rename(name) for name in sub_node.input] + sub_node.output[:] = [rename(name) for name in sub_node.output] + # Avoid name collision. + sub_node.name = f"{node.name}_{sub_node.name}" + + # TODO: we should handle initializers as well! + return list(graph.node) + return None + + +@register("Identity") +def identity(context: IRContext, node: onnx.NodeProto): + input = context.get_input(node, 0) + output = context.get_output(node, 0) + if input is not None and output is not None: + output.symbolic_value = input.name + + +@register("SequenceConstruct") +def sequence_construct( + context: IRContext, node: onnx.NodeProto +) -> Sequence[onnx.NodeProto] | None: + output = context.get_output(node, 0) + if output is not None: + output.symbolic_value = list(node.input) + return None + + +@register("ConcatFromSequence") +def concat_from_sequence( + context: IRContext, node: onnx.NodeProto +) -> Sequence[onnx.NodeProto] | None: + input = context.get_input(node, 0) + attrs = context.convert_attributes(node.attribute) + new_axis = attrs.get("new_axis", 0) + if input is not None and isinstance(input.symbolic_value, list): + if new_axis == 0: + node.op_type = "Concat" + node.input[:] = input.symbolic_value + logger.debug("ConcatFromSequence => Concat: %s", node.input) + for i in range(len(node.attribute)): + if node.attribute[i].name == "new_axis": + del node.attribute[i] + return [node] + return [node] + if new_axis == 1: + # Unsqueeze the inputs with concat axis if new_axis is 1 + axis = attrs.get("axis", None) + assert axis is not None + output = context.get_output(node, 0) + axis_node = context.new_constant(f"{output.name}_axis", np.array([axis]))[0] + unsqueeze_nodes = [] + for node_input in input.symbolic_value: + unsqueeze_node = onnx.helper.make_node( + "Unsqueeze", + [node_input, axis_node.output[0]], + [f"{node_input}_unsqueeze"], + ) + unsqueeze_nodes.append(unsqueeze_node) + unsqueeze_outputs = [n.output[0] for n in unsqueeze_nodes] + unsqueeze_nodes = [axis_node, *unsqueeze_nodes] + + # Send unsqueezed outputs to Concat + node.input[:] = unsqueeze_outputs + node.op_type = "Concat" + logger.debug( + "ConcatFromSequence => UnSqueeze %s + Concat %s", + unsqueeze_outputs, 
+ node.input, + ) + for i in range(len(node.attribute)): + if node.attribute[i].name == "new_axis": + del node.attribute[i] + return [*unsqueeze_nodes, node] + return None + + +@register("SplitToSequence") +def split_to_sequence( + context: IRContext, node: onnx.NodeProto +) -> Sequence[onnx.NodeProto] | None: + """Rewriting pattern. + + From + + splits = onnx::SplitToSequence(input, split, axis=axis) + + to + + split_0, split_1, ..., split_n = onnx::Split(input, split, axis=axis) + splits = onnx::SequenceConstruct(split_0, split_1, ..., split_n) + + or + + split_0, split_1, ..., split_n = onnx::Split(input, axis=axis, num_outputs=n+1) + splits = onnx::SequenceConstruct(split_0, split_1, ..., split_n) + + where number of output tensors in `splits` is statically known. + onnx::SequenceConstruct will be further optimized away if possible, by its own designated evaluator. + This allows downstream `SequenceAt` users to be replaced by `split_x` accordingly. + """ + input = context.get_input(node, 0) + split = context.get_input(node, 1) + attrs = context.convert_attributes(node.attribute) + output = context.get_output(node, 0) + + if input is None or split is None or output is None: + return None + + axis = attrs.get("axis", 0) + if input.type is None: + return None + split_dimension_size = get_dim_info(input.type, axis) + if split_dimension_size is None: + return None + + split_value = split.value + if split_value is None or split_value is ir.NotConstant: + return None + assert isinstance(split_value, np.ndarray) + + if split_value.ndim == 0: + # split into chunks all of size 'split' if possible. + num_outputs = math.ceil(split_dimension_size / split_value.item()) + split_outputs = [f"{output.name}_split_{i}" for i in range(num_outputs)] + split_node = onnx.helper.make_node( + "Split", + [input.name], + split_outputs, + axis=axis, + num_outputs=num_outputs, + ) + else: + # split into 'size(split)' chunks + num_outputs = split_value.size + split_outputs = [f"{output.name}_split_{i}" for i in range(num_outputs)] + split_node = onnx.helper.make_node( + "Split", + [input.name, split.name], + split_outputs, + axis=axis, + ) + + keepdims = attrs.get("keepdims", 1) + squeeze_nodes = [] + if keepdims == 0: + # squeeze the split dimension if keepdims is 0 + axis_node = context.new_constant(f"{output.name}_axis", np.array([axis]))[0] + for i in range(num_outputs): + squeeze_node = onnx.helper.make_node( + "Squeeze", + [split_outputs[i], axis_node.output[0]], + [f"{split_outputs[i]}_squeeze"], + ) + squeeze_nodes.append(squeeze_node) + split_outputs = [n.output[0] for n in squeeze_nodes] + squeeze_nodes = [axis_node, *squeeze_nodes] + + node.op_type = "SequenceConstruct" + node.input[:] = split_outputs + del node.attribute[:] + logger.debug( + "SplitToSequence => Split %s + SequenceConstruct %s", + split_node.input, + node.input, + ) + return [split_node, *squeeze_nodes, node] + + +@register("SequenceAt") +def sequence_at( + context: IRContext, node: onnx.NodeProto +) -> Sequence[onnx.NodeProto] | None: + input = context.get_input(node, 0) + position = context.get_input(node, 1) + output = context.get_output(node, 0) + if input is not None and position is not None: + input_vals = input.symbolic_value + position_val = position.value + if isinstance(input_vals, list) and position_val is not None: + output.symbolic_value = input_vals[position_val] + logger.debug("SquenceAt %s => %s", input, output.symbolic_value) + return None diff --git a/onnxscript/onnxrewriter/optimizer/fold_constants_v0.py 
b/onnxscript/onnxrewriter/optimizer/fold_constants_v0.py new file mode 100644 index 0000000000..2e8029797b --- /dev/null +++ b/onnxscript/onnxrewriter/optimizer/fold_constants_v0.py @@ -0,0 +1,256 @@ +from __future__ import annotations + +from typing import Any, Sequence + +import numpy as np +import onnx +import onnx.reference.ops + +# Excluded ops include +# * Random ops, which are not deterministic +# * Control flow ops + +excluded_ops = frozenset( + { + "RandomUniform", + "RandomNormal", + "RandomUniformLike", + "RandomNormalLike", + "Multinomial", + "If", + "Loop", + "Scan", + "SequenceMap", + } +) + +onnx_domain = frozenset({"", "onnx.ai"}) + + +def get_evaluator(domain: str, op: str, version: int) -> callable | None: + if op in excluded_ops and domain in onnx_domain: + return None + try: + op_impl_class = onnx.reference.ops.load_op(domain, op, version) + except Exception: # noqa: BLE001 + return None + else: + return op_impl_class.eval + + +def convert_attributes(attributes: Sequence[onnx.AttributeProto]) -> dict[str, Any]: + return {attr.name: onnx.helper.get_attribute_value(attr) for attr in attributes} + + +def is_control_flow_op(node: onnx.NodeProto) -> bool: + return any(attr.HasField("g") or len(attr.graphs) > 0 for attr in node.attribute) + + +def is_constant_op(node: onnx.NodeProto) -> bool: + return node.op_type == "Constant" and node.domain == "" + + +def get_bool_value(val) -> bool | None: + if isinstance(val, bool): + return val + if isinstance(val, np.bool_): + return bool(val) + if isinstance(val, np.ndarray) and val.size == 1 and val.dtype == bool: + return val.item(0) + return None + + +def get_shape_info(type: onnx.TypeProto) -> tuple[int, ...] | None: + if type.HasField("tensor_type") and type.tensor_type.HasField("shape"): + if all(d.HasField("dim_value") for d in type.tensor_type.shape.dim): + return np.array( + [d.dim_value for d in type.tensor_type.shape.dim], dtype=np.int64 + ) + return None + + +def get_element_type(type: onnx.TypeProto) -> int | None: + if type.HasField("tensor_type"): + return type.tensor_type.elem_type + return None + + +class State: + def __init__(self, default_value) -> None: + self.scopes = [{}] + self.default_value = default_value + + def lookup(self, name: str) -> Any: + for scope in reversed(self.scopes): + if name in scope: + return scope[name] + return self.default_value + + def bind(self, name: str, value: Any) -> None: + self.scopes[-1][name] = value + + def enter_scope(self) -> None: + self.scopes.append({}) + + def exit_scope(self) -> None: + self.scopes.pop() + + +def is_onnx_op(node: onnx.NodeProto, op: str) -> bool: + return (node.op_type == op) and (node.domain in onnx_domain) + + +def matches(node: onnx.NodeProto, op: str, *arg_predicates) -> bool: + if node.op_type != op or node.domain != "": + return False + if len(node.input) < len(arg_predicates): + return False + return all(pred(input) for pred, input in zip(arg_predicates, node.input)) + + +def get_initializer_type(initializer: onnx.TensorProto) -> onnx.TypeProto: + type = onnx.TypeProto() + type.tensor_type.elem_type = initializer.data_type + dims = type.tensor_type.shape.dim + for dim in initializer.dims: + dims.add().dim_value = dim + return type + + +def fold_constants(model: onnx.ModelProto): + not_constant = object() + var_info = State(default_value=not_constant) + type_info = State(default_value=None) + counts = {} + sizes = {} + + def add_count(op: str, size: int = 1): + counts[op] = counts.get(op, 0) + 1 + sizes[op] = sizes.get(op, 0) + size + + def 
new_constant(name, value): + var_info.bind(name, value) + tensor = onnx.numpy_helper.from_array(value, name=name) + node = onnx.helper.make_node( + "Constant", inputs=[], outputs=[name], value=tensor + ) + return node + + def lookup_version(domain: str, op: str) -> int: # noqa: ARG001 + for opset in model.opset_import: + if opset.domain == domain: + return opset.version + return 1 # TODO + + def transform_node(node: onnx.NodeProto): + if is_onnx_op(node, "Transpose"): + return [node] + if is_onnx_op(node, "CastLike"): + value = ( + var_info.lookup(node.input[0]) if len(node.input) > 0 else not_constant + ) + if value is not_constant: + return [node] + type = type_info.lookup(node.input[1]) if len(node.input) > 1 else None + element_type = get_element_type(type) if type is not None else None + if element_type is None: + return [node] + evaluator = get_evaluator("", "Cast", lookup_version("", "Cast")) + if evaluator is None: + return [node] + cast_value = evaluator(value, to=element_type) + add_count("CastLike", cast_value.size) + return [new_constant(node.output[0], cast_value)] + if is_onnx_op(node, "Shape"): + type = type_info.lookup(node.input[0]) if len(node.input) > 0 else None + shape = get_shape_info(type) if type is not None else None + if shape is not None: + add_count("Shape", shape.size) + return [new_constant(node.output[0], shape)] + + if is_onnx_op(node, "If"): + cond = var_info.lookup(node.input[0]) if len(node.input) > 0 else None + cond = get_bool_value(cond) + if cond is not None: + # cond is a constant-value: inline the branch + branch = "then_branch" if cond else "else_branch" + graph = onnx.helper.get_node_attr_value(node, branch) + formal_outs = list(graph.output) + actual_outs = node.output + renamings = { + formal.name: actual + for formal, actual in zip(formal_outs, actual_outs) + if actual != "" + } + + def rename(name): + return renamings.get(name, name) + + for node in graph.node: + node.input[:] = [rename(name) for name in node.input] + node.output[:] = [rename(name) for name in node.output] + transform_graph(graph) + add_count("If") + return list(graph.node) + + if is_control_flow_op(node): + for attr in node.attribute: + if attr.HasField("g"): + transform_graph(attr.g) + elif len(attr.graphs) > 0: + for graph in attr.graphs: + transform_graph(graph) + return [node] + + domain = node.domain + op = node.op_type + version = lookup_version(domain, op) + inputs = [] + for x in node.input: + if x == "": + inputs.append(None) + else: + v = var_info.lookup(x) + if v is not_constant: + return [node] + inputs.append(v) + evaluator = get_evaluator(domain, op, version) + if evaluator is None: + return [node] + attrs = convert_attributes(node.attribute) + outputs = evaluator(*inputs, **attrs) + if len(node.output) == 1 and not isinstance(outputs, tuple): + replacement = new_constant(node.output[0], outputs) + if is_constant_op(node): + return [node] + add_count(op, outputs.size) + return [replacement] + else: + add_count(op) + return [ + new_constant(output, outputs[i]) for i, output in enumerate(node.output) + ] + + def transform_graph(graph: onnx.GraphProto): + var_info.enter_scope() + type_info.enter_scope() + for initializer in graph.initializer: + array = onnx.numpy_helper.to_array(initializer) + var_info.bind(initializer.name, array) + type_info.bind(initializer.name, get_initializer_type(initializer)) + for input in graph.input: + var_info.bind(input.name, not_constant) + type_info.bind(input.name, input.type) + for valueinfo in graph.value_info: + 
type_info.bind(valueinfo.name, valueinfo.type) + + replacement = [transform_node(node) for node in graph.node] + flattened = [node for nodes in replacement for node in nodes] + del graph.node[:] + graph.node.extend(flattened) + var_info.exit_scope() + type_info.exit_scope() + + transform_graph(model.graph) + for op in counts: + print(f"Constant-folded '{op}' {counts[op]} times, with {sizes[op]} size.") diff --git a/onnxscript/onnxrewriter/optimizer/function_folding_test.py b/onnxscript/onnxrewriter/optimizer/function_folding_test.py new file mode 100644 index 0000000000..53b537ba01 --- /dev/null +++ b/onnxscript/onnxrewriter/optimizer/function_folding_test.py @@ -0,0 +1,162 @@ +import unittest + +import onnx + +from onnxrewriter import optimizer + + +class FunctionFoldingTest(unittest.TestCase): + def test_identity(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x1, bool cond1) => (float[N] z1) { + z1 = local.fun1(x1, cond1) + } + + fun1 (x, cond) => (z) { + t = Identity(x) + t2 = Identity(t) + t3 = If (cond) < + then_branch = then_graph() => (t4) { + t5 = Identity(t2) + t4 = Identity(t5) + }, + else_branch = else__graph() => (t6) { + t7 = Identity(t) + t6 = Identity(t7) + } + > + t4 = Add(t3, t3) + z = Identity(t4) + } + """ + ) + optimized = optimizer.optimize( + model, + onnx_shape_inference=False, + num_iterations=1, + ) + self.assertEqual(len(optimized.functions), 0) + self.assertEqual(len(optimized.graph.node), 2) + + def test_sequence_concat(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x1) => (float[M] z1) { + z1 = local.fun1(x1) + } + + fun1 (x) => (z) { + t0 = Add (x, x) + t2 = Add (x, x) + t3 = SequenceConstruct (x, t0, t2, x) + z = ConcatFromSequence (t3) + } + """ + ) + optimized = optimizer.optimize( + model, + onnx_shape_inference=False, + num_iterations=1, + ) + function_node = optimized.functions[0].node + self.assertEqual(len(function_node), 3) + self.assertEqual(function_node[2].op_type, "Concat") + + def test_single_user_function_is_modified_inplace_after_folding(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x1) => (float[M] z1) { + z1 = local.fun1(x1) + } + + fun1 (x) => (z) { + t0 = Add (x, x) + t2 = Add (x, x) + t3 = SequenceConstruct (x, t0, t2, x) + z = ConcatFromSequence (t3) + } + """ + ) + optimized = optimizer.optimize( + model, + onnx_shape_inference=False, + num_iterations=1, + ) + self.assertEqual(optimized.functions[0].name, "fun1") + + def test_multi_users_function_is_not_modified_inplace_after_folding(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x1) => (float[M] z1, float[M] z2) { + z1 = local.fun1(x1) + z2 = local.fun1(x1) + } + + fun1 (x) => (z) { + t0 = Add (x, x) + t2 = Add (x, x) + t3 = SequenceConstruct (x, t0, t2, x) + z = ConcatFromSequence (t3) + } + """ + ) + optimized = optimizer.optimize( + model, + onnx_shape_inference=False, + num_iterations=1, + ) + self.assertEqual(len(optimized.functions), 2) + self.assertNotEqual(optimized.functions[0].name, "fun1") + self.assertNotEqual(optimized.functions[1].name, "fun1") + + def test_fold_nested_if_function_succeeds(self): + model = onnx.parser.parse_model( + """ +< + ir_version: 9, + opset_import: ["this" : 1, "" : 21] +> +func (float[1,512] x, float[1,512] y) => ( out) { + out = this.foldable_func (x, y) +} +< + domain: "this", + opset_import: ["" : 18] +> +foldable_func (x, y) => (z_6) +{ + cond = Constant () + z_6 = If (cond) ( z_2) { + cond_0 = Not (cond) + z_2 = If (cond_0) ( z) { + z = Add (x, x) + }, 
else_branch: graph = elseGraph_5 () => ( z_1) { + z_1 = Identity (x) + }> + }, else_branch: graph = elseGraph_4 () => ( z_5) { + z_5 = If (cond) ( z_3) { + z_3 = Add (y, y) + }, else_branch: graph = elseGraph_10 () => ( z_4) { + z_4 = Add (x, y) + }> + }> +} + """ + ) + optimized = optimizer.optimize( + model, + onnx_shape_inference=False, + ) + + self.assertEqual(len(optimized.functions), 0) + self.assertEqual(len(optimized.graph.node), 1) + self.assertNotIn("If", {n.op_type for n in optimized.graph.node}) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/onnxrewriter/optimizer/remove_unused.py b/onnxscript/onnxrewriter/optimizer/remove_unused.py new file mode 100644 index 0000000000..ea9bf88c2b --- /dev/null +++ b/onnxscript/onnxrewriter/optimizer/remove_unused.py @@ -0,0 +1,129 @@ +from __future__ import annotations + +import logging +from typing import Sequence + +import onnx +from google.protobuf.internal.containers import ( # type: ignore + RepeatedCompositeFieldContainer, +) + +logger = logging.getLogger(__name__) + + +def remove_unused_optional_outputs( + n: onnx.NodeProto, used: set, opset_import: Sequence[onnx.OperatorSetIdProto] +) -> None: + try: + if n.domain not in {"", "onnx.ai"}: + return + onnx_opset_version = 1 + for opset in opset_import: + if opset.domain == n.domain: + onnx_opset_version = opset.version + op_schema = onnx.defs.get_schema(n.op_type, onnx_opset_version, domain=n.domain) + except Exception: # noqa: BLE001 + return + # TODO: If current node is a BatchNormalization node, + # based on training_mode atrribute, number of optional outputs and + # how they are handled varies, handle both training_modes + if n.op_type == "BatchNormalization": + return + optional_info = [] + for o in op_schema.outputs: + # Current ops do not have optional outputs if they have variable number of outputs + if o.option == onnx.defs.OpSchema.FormalParameterOption.Variadic: + return + optional_info.append( + o.option == onnx.defs.OpSchema.FormalParameterOption.Optional + ) + # If no optional outputs in spec, skip delete operations + if len([o == 1 for o in optional_info]) == 0: + return + + for i, out in enumerate(n.output): + if out not in used and optional_info[i] is True: + n.output[i] = "" + # Only delete trailing unused optional outputs + for o in n.output[::-1]: # type: ignore[assignment] + if o == "": + n.output.pop() + else: + return + + +def compute_used_in_node(n: onnx.NodeProto) -> set[str]: + used = {n for n in n.input if n != ""} + for attr in n.attribute: + if attr.HasField("g"): + used |= compute_used_in_graph(attr.g) + elif len(attr.graphs) > 0: + for graph in attr.graphs: + used |= compute_used_in_graph(graph) + return used + + +def compute_used_in_graph(g: onnx.GraphProto) -> set[str]: + used = set() + for n in g.node: + used |= compute_used_in_node(n) + return used + + +def process_nodes( + nodes: RepeatedCompositeFieldContainer[onnx.NodeProto], + used: set, + opset_import: Sequence[onnx.OperatorSetIdProto], +) -> int: + count = 0 + i = len(nodes) - 1 + while i >= 0: + node = nodes[i] + remove_unused_optional_outputs(node, used, opset_import) + used_outputs = [x for x in node.output if x in used] + if not used_outputs: + del nodes[i] + count += 1 + i -= 1 + continue + for attr in node.attribute: + if attr.HasField("g"): + process_graph(attr.g, opset_import) + elif len(attr.graphs) > 0: + for graph in attr.graphs: + process_graph(graph, opset_import) + used |= compute_used_in_node(node) + i -= 1 + return count + + +def process_graph( + graph: 
onnx.GraphProto, opset_import: Sequence[onnx.OperatorSetIdProto] +) -> int: + used = {output.name for output in graph.output} + + count = process_nodes(graph.node, used, opset_import) + + for i in range(len(graph.initializer) - 1, -1, -1): + if graph.initializer[i].name not in used: + del graph.initializer[i] + count += 1 + + return count + + +def process_function( + function: onnx.FunctionProto, opset_import: Sequence[onnx.OperatorSetIdProto] +) -> int: + used = set(function.output) + + return process_nodes(function.node, used, opset_import) + + +def remove_unused_nodes(model: onnx.ModelProto) -> None: + """Removes unused nodes from the model.""" + count = process_graph(model.graph, model.opset_import) + for function in model.functions: + count += process_function(function, model.opset_import) + + logger.info("Removed %s unused nodes", count) diff --git a/onnxscript/onnxrewriter/optimizer/remove_unused_function.py b/onnxscript/onnxrewriter/optimizer/remove_unused_function.py new file mode 100644 index 0000000000..573dfaa8b1 --- /dev/null +++ b/onnxscript/onnxrewriter/optimizer/remove_unused_function.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +import logging + +import onnx +from google.protobuf.internal.containers import ( # type: ignore + RepeatedCompositeFieldContainer, +) + +logger = logging.getLogger(__name__) + + +class UnusedFunctionRemover: + def compute_used_in_node(self, n: onnx.NodeProto) -> set[tuple[str, str]]: + used = {(n.domain, n.op_type)} + for attr in n.attribute: + if attr.HasField("g"): + used |= self.process_graph(attr.g) + elif len(attr.graphs) > 0: + for graph in attr.graphs: + used |= self.process_graph(graph) + if (n.domain, n.op_type) in self._functions: + function = self._functions[(n.domain, n.op_type)] + used |= self.process_function(function) + return used + + def process_nodes( + self, nodes: RepeatedCompositeFieldContainer[onnx.NodeProto] + ) -> set[tuple[str, str]]: + used = set() + for node in nodes: + used |= self.compute_used_in_node(node) + return used + + def process_graph(self, graph: onnx.GraphProto) -> set[tuple[str, str]]: + return self.process_nodes(graph.node) + + def process_function(self, function: onnx.FunctionProto) -> set[tuple[str, str]]: + return self.process_nodes(function.node) + + def process_model(self, model: onnx.ModelProto) -> None: + self._functions = {(f.domain, f.name): f for f in model.functions} + used = self.process_graph(model.graph) + count = 0 + logger.debug("Used function protos: %s", used) + for i in range(len(model.functions) - 1, -1, -1): + if (model.functions[i].domain, model.functions[i].name) not in used: + del model.functions[i] + count += 1 + logger.info("Removed %s unused function protos", count) + logger.debug("Function protos left: %s", [f.name for f in model.functions]) + + +def remove_unused_functions(model: onnx.ModelProto) -> None: + """Removes unused function protos from the model.""" + UnusedFunctionRemover().process_model(model) diff --git a/onnxscript/onnxrewriter/optimizer/remove_unused_test.py b/onnxscript/onnxrewriter/optimizer/remove_unused_test.py new file mode 100644 index 0000000000..06057f6e9a --- /dev/null +++ b/onnxscript/onnxrewriter/optimizer/remove_unused_test.py @@ -0,0 +1,173 @@ +import unittest + +import onnx + +from onnxrewriter import optimizer + + +class RemoveUnusedTest(unittest.TestCase): + def test_remove_unused_nodes(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x) => (float[N] z) { + two = Constant () + four = Add(two, two) + z = Mul(x, x) + } 
+ """ + ) + optimizer.remove_unused_nodes(model) + self.assertEqual(len(model.graph.node), 1) + self.assertEqual(model.graph.node[0].op_type, "Mul") + + def test_remove_unused_initializers(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x) => (float[N] z) + { + four = Add(two, two) + z = Mul(x, x) + } + """ + ) + self.assertEqual(len(model.graph.initializer), 1) + optimizer.remove_unused_nodes(model) + self.assertEqual(len(model.graph.node), 1) + self.assertEqual(model.graph.node[0].op_type, "Mul") + self.assertEqual(len(model.graph.initializer), 0) + + def test_partially_used_nodes(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x) => (float[M] z) { + w1, w2, w3 = Split (x) + z = Mul(w3, w3) + } + """ + ) + optimizer.remove_unused_nodes(model) + self.assertEqual(len(model.graph.node), 2) + self.assertEqual(model.graph.node[0].op_type, "Split") + + def test_remove_unused_optional_outputs_maxpool(self): + model = onnx.parser.parse_model( + """ + + agraph (float[1, 1, 5, 5] x) => (float[1, 1, 5, 5] z) { + z, indices = MaxPool (x) + } + """ + ) + self.assertEqual(len(model.graph.node), 1) + self.assertEqual(model.graph.node[0].op_type, "MaxPool") + self.assertEqual(len(model.graph.node[0].output), 2) + optimizer.remove_unused_nodes(model) + self.assertEqual(len(model.graph.node), 1) + self.assertEqual(model.graph.node[0].op_type, "MaxPool") + self.assertEqual(len(model.graph.node[0].output), 1) + + def test_remove_unused_optional_outputs_dropout_in_function(self): + model = onnx.parser.parse_model( + """ + + agraph (float[1, 1, 5, 5] x) => (float[1, 1, 5, 5] z) + { + z = pkg.custom.afunction (x) + } + + afunction (x) => (z) + { + z, indices = MaxPool (x) + } + """ + ) + self.assertEqual(len(model.functions), 1) + self.assertEqual(len(model.functions[0].node), 1) + self.assertEqual(model.functions[0].node[0].op_type, "MaxPool") + self.assertEqual(len(model.functions[0].node[0].output), 2) + optimizer.remove_unused_nodes(model) + self.assertEqual(len(model.functions), 1) + self.assertEqual(len(model.functions[0].node), 1) + self.assertEqual(model.functions[0].node[0].op_type, "MaxPool") + self.assertEqual(len(model.functions[0].node[0].output), 1) + + def test_remove_used_optional_outputs_maxpool(self): + model = onnx.parser.parse_model( + """ + + agraph (float[1, 1, 5, 5] x) => (float[1, 1, 5, 5] y, float[1, 1, 5, 5] z) { + y, z = MaxPool (x) + } + """ + ) + self.assertEqual(len(model.graph.node), 1) + self.assertEqual(model.graph.node[0].op_type, "MaxPool") + self.assertEqual(len(model.graph.node[0].output), 2) + optimizer.remove_unused_nodes(model) + self.assertEqual(len(model.graph.node), 1) + self.assertEqual(model.graph.node[0].op_type, "MaxPool") + self.assertEqual(len(model.graph.node[0].output), 2) + + def test_remove_multiple_unused_optional_outputs_layernorm(self): + model = onnx.parser.parse_model( + """ + + agraph (float[1, 3, 5, 5] x) => (float[1, 3, 5, 5] z) { + scale = Constant () + B = Constant () + z, mean, InvStdDev = LayerNormalization(x, scale, B) + } + """ + ) + self.assertEqual(len(model.graph.node), 3) + self.assertEqual(model.graph.node[2].op_type, "LayerNormalization") + self.assertEqual(len(model.graph.node[2].output), 3) + optimizer.remove_unused_nodes(model) + self.assertEqual(len(model.graph.node), 3) + self.assertEqual(model.graph.node[2].op_type, "LayerNormalization") + self.assertEqual(len(model.graph.node[2].output), 1) + + def test_remove_trailing_unused_optional_outputs_layernorm(self): + model = 
onnx.parser.parse_model( + """ + + agraph (float[1, 3, 5, 5] x) => (float[1, 3, 5, 5] z, float[1, 3, 5, 5] mean) { + scale = Constant () + B = Constant () + z, mean, InvStdDev = LayerNormalization(x, scale, B) + } + """ + ) + self.assertEqual(len(model.graph.node), 3) + self.assertEqual(model.graph.node[2].op_type, "LayerNormalization") + self.assertEqual(len(model.graph.node[2].output), 3) + optimizer.remove_unused_nodes(model) + self.assertEqual(len(model.graph.node), 3) + self.assertEqual(model.graph.node[2].op_type, "LayerNormalization") + self.assertEqual(len(model.graph.node[2].output), 2) + + def test_avoid_remove_non_trailing_unused_optional_outputs_layernorm(self): + model = onnx.parser.parse_model( + """ + + agraph (float[1, 3, 5, 5] x) => (float[1, 3, 5, 5] z, float[1, 3, 5, 5] InvStdDev) { + scale = Constant () + B = Constant () + z, mean, InvStdDev = LayerNormalization(x, scale, B) + } + """ + ) + self.assertEqual(len(model.graph.node), 3) + self.assertEqual(model.graph.node[2].op_type, "LayerNormalization") + self.assertEqual(len(model.graph.node[2].output), 3) + optimizer.remove_unused_nodes(model) + self.assertEqual(len(model.graph.node), 3) + self.assertEqual(model.graph.node[2].op_type, "LayerNormalization") + self.assertEqual(len(model.graph.node[2].output), 3) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/onnxrewriter/optimizer/simple_function_folding.py b/onnxscript/onnxrewriter/optimizer/simple_function_folding.py new file mode 100644 index 0000000000..1e25358363 --- /dev/null +++ b/onnxscript/onnxrewriter/optimizer/simple_function_folding.py @@ -0,0 +1,246 @@ +"""Inlines the function if it only contains very few number of nodes.""" + +from __future__ import annotations + +import logging +from typing import Sequence + +import onnx + +from onnxrewriter import ir +from onnxrewriter.ir import visitor +from onnxrewriter.optimizer import remove_unused + +logger = logging.getLogger(__name__) + + +class FunctionInliner(visitor.FunctionCallsiteProtoTransformer): + counts: dict[ir.FunctionId, int] + + def __init__(self, node_count: int) -> None: + super().__init__() + self._node_count = node_count + + def _gather_function_metadata(self, model: onnx.ModelProto) -> None: + super()._gather_function_metadata(model) + self._function_renamer._postfix = "inlined" + + def visit_model(self, model: onnx.ModelProto) -> None: + self.counts = {} + + super().visit_model(model) + + def should_inline_function(self, function: onnx.FunctionProto) -> bool: + return len(function.node) <= self._node_count + + def process_function_node( + self, node: onnx.NodeProto + ) -> tuple[list[onnx.NodeProto] | None, onnx.FunctionProto | None]: + # Recursively process sub nodes first. + function_id = (node.domain, node.op_type, getattr(node, "overload", "")) + function = self._functions[function_id] + replacement, new_function = super().process_function_node(node) + function = new_function if new_function else function + + if self.should_inline_function(function): + self.enter_function_scope(function) + sub_scope = self.exit_function_scope(function) + new_nodes = [] + + formal_outs = function.output + actual_outs = node.output + formal_ins = function.input + actual_ins = node.input + # TODO: Potential collision when actual is "". + # formal.name may collide with existing value names. 
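+ # Build the renaming map used for inlining this call: each formal input maps
+ # to the actual input name at the call site (missing trailing actuals are
+ # treated as omitted optional inputs and map to ""), and each formal output
+ # maps to the node's actual output. Values produced inside the function that
+ # are not in this map are intermediates; rename() below prefixes them with
+ # the call-site node name (e.g. an intermediate "t" of a call named "n0"
+ # becomes "n0_t"; names here are illustrative) to avoid collisions in the caller.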
+ input_renamings = dict(zip(formal_ins, actual_ins)) + if len(actual_ins) < len(formal_ins): + input_renamings.update(dict.fromkeys(formal_ins[len(actual_ins) :], "")) + output_renamings = { + formal: actual + for formal, actual in zip(formal_outs, actual_outs) + if actual != "" + } + renamings = {**input_renamings, **output_renamings} + + logger.debug("renamings function %s: %s", function.name, renamings) + + def rename(name: str) -> str: + if name == "": + return name + new_name = renamings.get(name) + if new_name is None: + new_name = f"{node.name}_{name}" + logger.debug("renaming %s to %s", name, new_name) + if (ir_value := sub_scope.lookup(name)) is not None: + if ( + ir_value.tensor_shape_proto() is not None + and ir_value.type is not None + ): + ir_value.name = new_name + self.bind(new_name, ir_value) + return new_name + + ref_attrs = {attr.name: attr for attr in node.attribute} + # logger.debug("inlining simple function %s. Ref attrs: %s", function.name, ref_attrs) + + def fill_in_ref(attr: onnx.AttributeProto) -> onnx.AttributeProto: + if attr.ref_attr_name: + new_attr = onnx.AttributeProto() + new_attr.CopyFrom(ref_attrs[attr.ref_attr_name]) + new_attr.name = attr.name + return new_attr + return attr + + def update_graph_attribute( + attr: onnx.AttributeProto, + ) -> onnx.AttributeProto: + if attr.g: + new_attr = onnx.AttributeProto() + new_attr.CopyFrom(attr) + for node in new_attr.g.node: + node.input[:] = [rename(name) for name in node.input] + node.output[:] = [rename(name) for name in node.output] + new_attrs = [] + for attr in node.attribute: + new_attrs.append(update_attribute(attr)) + del node.attribute[:] + node.attribute.extend(new_attrs) + for vi_proto in new_attr.g.input: + vi_proto.name = rename(vi_proto.name) + for vi_proto in new_attr.g.output: + vi_proto.name = rename(vi_proto.name) + return new_attr + return attr + + def update_attribute(attr: onnx.AttributeProto) -> onnx.AttributeProto: + new_attr = fill_in_ref(attr) + new_attr = update_graph_attribute(new_attr) + return new_attr + + for sub_node in function.node: + # logger.debug("inlining simple function. old node: %s", sub_node) + new_node = onnx.NodeProto() + new_node.CopyFrom(sub_node) + new_node.input[:] = [rename(name) for name in new_node.input] + new_node.output[:] = [rename(name) for name in new_node.output] + del new_node.attribute[:] + for attr in sub_node.attribute: + new_node.attribute.append(update_attribute(attr)) + # Avoid name collision. + new_node.name = f"{node.name}_{new_node.name}" + # logger.debug("inlining simple function. new node: %s", new_node) + new_nodes.append(new_node) + + self.counts.setdefault(function_id, 0) + self.counts[function_id] += 1 + + return new_nodes, None + + return replacement, new_function + + +class SelectedFunctionInliner(FunctionInliner): + def __init__(self, functions_to_inline: Sequence[onnx.FunctionProto]): + super().__init__(node_count=0) # node_count unused. 
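+ # Inline only the FunctionProtos explicitly supplied by the caller;
+ # should_inline_function() below tests membership in this sequence rather
+ # than the node-count threshold used by the base FunctionInliner.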
+ self._functions_to_inline = functions_to_inline + + def should_inline_function(self, function: onnx.FunctionProto) -> bool: + return function in self._functions_to_inline + + +class FindFunctionWithUnusedOutputsVisitor(visitor.ProtoVisitor): + def __init__(self) -> None: + super().__init__() + self._function_with_unused_outputs: dict[ir.FunctionId, onnx.FunctionProto] = {} + self._functions: dict[ir.FunctionId, onnx.FunctionProto] = {} + self._used_nodes: list[onnx.NodeProto] = [] + + def _find_nodes_with_any_unused_output( + self, nodes: Sequence[onnx.NodeProto], used_values: set[str] + ) -> list[onnx.NodeProto]: + target_nodes = [] + for i in range(len(nodes) - 1, -1, -1): + node = nodes[i] + if any(x not in used_values for x in node.output): + # Any unused output means the node is a target node. + target_nodes.append(node) + if all(x not in used_values for x in node.output): + # All unused output means the node is not used at all. + # Hence do not update used_values with the node's inputs. + continue + used_values |= remove_unused.compute_used_in_node(node) + return target_nodes + + def visit_model(self, model: onnx.ModelProto) -> None: + used_values = {output.name for output in model.graph.output} + target_nodes = self._find_nodes_with_any_unused_output( + model.graph.node, used_values + ) + + for function in model.functions: + self._functions[ + (function.domain, function.name, getattr(function, "overload", "")) + ] = function + used_values = set(function.output) + target_nodes.extend( + self._find_nodes_with_any_unused_output(function.node, used_values) + ) + + for node in target_nodes: + if visitor.is_local_function_node(node, self._functions): + function_id = (node.domain, node.op_type, getattr(node, "overload", "")) + self._function_with_unused_outputs[function_id] = self._functions[ + function_id + ] + + logger.info( + "Found %s function nodes that have unused outputs.", + len(self._function_with_unused_outputs), + ) + for key in self._function_with_unused_outputs: + logger.info("Function node with unused outputs: %s::%s", key[0], key[1]) + + @property + def function_with_unused_outputs(self) -> dict[ir.FunctionId, onnx.FunctionProto]: + return self._function_with_unused_outputs + + +def inline_simple_functions(model: onnx.ModelProto, node_count: int = 2) -> bool: + inliner = FunctionInliner(node_count) + inliner.visit_model(model) + logger.info( + "inlined %s simple functions based on node count threshold %s.", + len(inliner.counts), + node_count, + ) + for op in inliner.counts: + logger.info( + "Inlined simple function '%s::%s' %s times.", + op[0], + op[1], + inliner.counts[op], + ) + return inliner.modified + + +def inline_functions_with_unused_outputs(model: onnx.ModelProto) -> bool: + # TODO: Use onnx.inliner after 1.16. + # This visitor based inliner is used to ensure the function inner value info remains consistent. 
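+ # Two passes: first find every local-function callsite with at least one
+ # unused output, then inline just those callsites. As in the tests, callers
+ # typically follow this with a cleanup pass, e.g.:
+ #   inline_functions_with_unused_outputs(model)
+ #   remove_unused_function.remove_unused_functions(model)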
+ visitor = FindFunctionWithUnusedOutputsVisitor() + visitor.visit_model(model) + # FIXME: Fix the type of the argument passed into SelectedFunctionInliner + inliner = SelectedFunctionInliner(visitor.function_with_unused_outputs.values()) # type: ignore[arg-type] + inliner.visit_model(model) + logger.info( + "inlined %s function nodes that have unused outputs.", + len(inliner.counts), + ) + for op in inliner.counts: + logger.info( + "Inlined function '%s::%s' %s times.", + op[0], + op[1], + inliner.counts[op], + ) + return inliner.modified diff --git a/onnxscript/onnxrewriter/optimizer/simple_function_folding_test.py b/onnxscript/onnxrewriter/optimizer/simple_function_folding_test.py new file mode 100644 index 0000000000..b18cd16228 --- /dev/null +++ b/onnxscript/onnxrewriter/optimizer/simple_function_folding_test.py @@ -0,0 +1,218 @@ +from __future__ import annotations + +import unittest + +import onnx + +from onnxrewriter.optimizer import remove_unused_function, simple_function_folding + + +class SingleNodeFunctionFoldingTest(unittest.TestCase): + def test_fold_single_node_function(self): + model = onnx.parser.parse_model( + """ +< + ir_version: 8, + opset_import: ["this" : 1, "" : 18] +> +func ( x, y) => ( return_val) { + tmp = this.foldable (x) + return_val = Add (tmp, y) +} +< + domain: "this", + opset_import: ["" : 18] +> +foldable (x) => (return_val) +{ + return_val = Identity (x) +} + """ + ) + + simple_function_folding.inline_simple_functions(model) + remove_unused_function.remove_unused_functions(model) + + self.assertEqual(len(model.functions), 0) + + def test_fold_single_node_function_ref_attr(self): + model = onnx.parser.parse_model( + """ +< + ir_version: 8, + opset_import: ["this" : 1, "" : 18] +> +func ( x, y, z) => ( return_val) { + tmp = this.foldable (x, y) + return_val = Add (tmp, z) +} +< + domain: "this", + opset_import: ["" : 18] +> +foldable (x, y) => (return_val) +{ + return_val = Concat (x, y) +} + """ + ) + + simple_function_folding.inline_simple_functions(model) + remove_unused_function.remove_unused_functions(model) + + self.assertEqual(len(model.functions), 0) + self.assertFalse(model.graph.node[0].attribute[0].ref_attr_name) + self.assertEqual(model.graph.node[0].attribute[0].name, "axis") + + def test_fold_single_node_function_nested(self): + model = onnx.parser.parse_model( + """ +< + ir_version: 8, + opset_import: ["this" : 1, "" : 18] +> +func ( x, y, z) => ( return_val) { + tmp = this.non_foldable (x, y) + return_val = Add (tmp, z) +} +< + domain: "this", + opset_import: ["" : 18] +> +foldable (x, y) => (return_val) +{ + return_val = Concat (x, y) +} +< + domain: "this", + opset_import: ["this" : 1,"" : 18] +> +non_foldable (x, y) => (return_val) +{ + tmp = this.foldable (x, y) + tmp_0 = this.foldable (x, y) + return_val = Add (tmp, tmp_0) +} + """ + ) + + simple_function_folding.inline_simple_functions(model) + remove_unused_function.remove_unused_functions(model) + + self.assertEqual(len(model.functions), 1) + self.assertEqual(model.functions[0].node[0].op_type, "Concat") + self.assertEqual(model.functions[0].node[1].op_type, "Concat") + + def test_fold_single_node_function_create_new_nodes_with_correct_attributes(self): + model = onnx.parser.parse_model( + """ +< + ir_version: 9, + opset_import: ["this" : 1, "" : 21] +> +func (float[1,512] x) => ( a, b, c) { + a = this.prim_cast (x) + b = this.prim_cast (x) + c = this.prim_cast (x) +} +< + domain: "this", + opset_import: ["" : 18] +> +prim_cast (x) => (return_val) +{ + return_val = Cast (x) +} + """ + 
) + simple_function_folding.inline_simple_functions(model) + remove_unused_function.remove_unused_functions(model) + self.assertEqual(len(model.functions), 0) + self.assertEqual(len(model.graph.node), 3) + self.assertEqual(model.graph.node[0].attribute[0].i, 10) + self.assertEqual(model.graph.node[1].attribute[0].i, 6) + self.assertEqual(model.graph.node[2].attribute[0].i, 7) + + def test_fold_nested_if_function_succeeds(self): + model = onnx.parser.parse_model( + """ +< + ir_version: 9, + opset_import: ["this" : 1, "" : 21] +> +func (float[1,512] x, float[1,512] y) => ( out) { + out = this.foldable_func (x, y) +} +< + domain: "this", + opset_import: ["" : 18] +> +foldable_func (x, y) => (z_6) +{ + cond = Constant () + z_6 = If (cond) ( z_2) { + cond_0 = Not (cond) + z_2 = If (cond_0) ( z) { + z = Add (x, x) + }, else_branch: graph = elseGraph_5 () => ( z_1) { + z_1 = Identity (x) + }> + }, else_branch: graph = elseGraph_4 () => ( z_5) { + z_5 = If (cond) ( z_3) { + z_3 = Add (y, y) + }, else_branch: graph = elseGraph_10 () => ( z_4) { + z_4 = Add (x, y) + }> + }> +} + """ + ) + + simple_function_folding.inline_simple_functions(model) + remove_unused_function.remove_unused_functions(model) + + self.assertEqual(len(model.functions), 0) + self.assertEqual(len(model.graph.node), 2) + self.assertEqual(model.graph.node[1].op_type, "If") + + def test_fold_function_with_unused_output(self): + model = onnx.parser.parse_model( + """ +< + ir_version: 8, + opset_import: ["this" : 1, "" : 18] +> +func ( x, y, z) => ( return_val) { + tmp = this.non_foldable (x, y) + return_val = Add (tmp, z) +} +< + domain: "this", + opset_import: ["" : 18] +> +foldable (x, y) => (return_val, unused, unused1) +{ + return_val = Concat (x, y) + unused = Identity (x) + unused1 = Identity (y) +} +< + domain: "this", + opset_import: ["this" : 1,"" : 18] +> +non_foldable (x, y) => (return_val) +{ + tmp, unused, unused1 = this.foldable (x, y) + tmp_0, unused2, unused3 = this.foldable (x, y) + return_val = Add (tmp, tmp_0) +} + """ + ) + + simple_function_folding.inline_functions_with_unused_outputs(model) + remove_unused_function.remove_unused_functions(model) + self.assertEqual(len(model.functions), 1) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/onnxrewriter/rewriter/__init__.py b/onnxscript/onnxrewriter/rewriter/__init__.py new file mode 100644 index 0000000000..3782640617 --- /dev/null +++ b/onnxscript/onnxrewriter/rewriter/__init__.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from typing import Sequence + +__all__ = [ + # Modules + "irbuilder", + "protobuilder", + "function_rule", + "pattern", + # Functions + "rewrite", +] + +import onnx + +from onnxrewriter.ir import irbuilder, protobuilder +from onnxrewriter.rewriter import function_rule, pattern + +PatternRewriteRule = pattern.RewriteRule +FunctionRewriteRule = function_rule.FunctionRewriteRule + + +def rewrite( + model: onnx.ModelProto, + function_rewrite_rules: Sequence[type[FunctionRewriteRule]] = (), + pattern_rewrite_rules: Sequence[PatternRewriteRule] = (), +) -> onnx.ModelProto: + if function_rewrite_rules: + model_ir = irbuilder.build_ir(model) + for rule_cls in function_rewrite_rules: + rule_cls().apply_to_model(model_ir) + model = model_ir.original_model_proto + if pattern_rewrite_rules: + model_ir = irbuilder.build_ir(model) + count = pattern.RewriteRuleSet(pattern_rewrite_rules).apply_to_model(model_ir) + print(f"Applied {count} pattern rewrite rules.") + model = protobuilder.build_model_proto(model_ir) + return 
model diff --git a/onnxscript/onnxrewriter/rewriter/broadcast_to_matmul.py b/onnxscript/onnxrewriter/rewriter/broadcast_to_matmul.py new file mode 100644 index 0000000000..724645242f --- /dev/null +++ b/onnxscript/onnxrewriter/rewriter/broadcast_to_matmul.py @@ -0,0 +1,177 @@ +from __future__ import annotations + +import logging +from typing import Any + +import numpy as np + +from onnxrewriter import ir +from onnxrewriter.rewriter import pattern + +op = pattern.onnxop +logger = logging.getLogger(__name__) + + +# condition to check if we need to replace the pattern +def check_if_need_reshape(match_bindings: dict[str, ir.Value | Any]) -> bool: + """If matmul broadcasting is enough, then we don't need the reshapes. + + To validate this, we need to check the following: + 1. Input shapes check: input_a and input_b should be broadcastable + 2. Output shape check: shape_c should be the same as the output shape from the matmul(input_a, input_b) + + If the above are true, then we don't need the reshapes. + + Args: + match_bindings: The match binding dictionary from a MatchResult. + + Returns: + bool: True if we need to replace the pattern, False otherwise. + + """ + input_a_shape = match_bindings["input_a"].shape + input_b_shape = match_bindings["input_b"].shape + shape_c = match_bindings["shape_c"].value_as_np_array + if shape_c is None: + return False + if not isinstance(shape_c, np.ndarray): + logger.info( + "Unexpected shape_c value. Expected np.ndarray, got %s", type(shape_c) + ) + return False + if len(shape_c.shape) != 1: + logger.info( + "Unexpected final shape. The shape of 'shape' value is %s", + shape_c.shape, + ) + return False + shape_c = shape_c.tolist() + + # NOTE: When there is a subset match with a pattern. The MatchResult won't have the shape + # information. So, we need to check if the shape is None and return False. + if input_a_shape is None or input_b_shape is None or shape_c is None: + logger.info("Shape information is not available for the inputs and outputs.") + return False + + dim_a = len(input_a_shape) + dim_b = len(input_b_shape) + + # 1. Check if input shapes are broadcastable + # 1.a. If the first input is 1-D, check whether + # the dim matches the last second dim of the second input. + mimic_matmul_broadcast_behavior = False + if dim_a < 2: # noqa: PLR2004 + if input_a_shape[-1] != input_b_shape[-2]: + logger.info("Original shape is not MatMul compatible.") + return False + else: + input_a_shape = [1, *input_a_shape] + dim_a = len(input_a_shape) + mimic_matmul_broadcast_behavior = True + # 1.b. If the second input is 1-D, check whether + # the dim matches the last dim of the first input. + if dim_b < 2: # noqa: PLR2004 + if input_b_shape[-1] != input_a_shape[-1]: + logger.info("Original shape is not MatMul compatible.") + return False + else: + input_b_shape = [*input_b_shape, 1] + dim_b = len(input_b_shape) + mimic_matmul_broadcast_behavior = True + # 1.c. If both inputs are at least 2-D, check whether + # the last dimension of the first input matches the second + # last dimension of the second input, and shape[:-2] are + # broadcastable. 
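+ # Worked example (the shapes used in broadcast_to_matmul_test): with
+ # input_a_shape = [1, 4, 512, 512] and input_b_shape = [1, 4, 512, 64], the
+ # batch dims [1, 4] broadcast trivially and the matmul dims are
+ # (512 x 512) @ (512 x 64) -> (512 x 64), so the broadcast MatMul output is
+ # [1, 4, 512, 64]; if shape_c equals that, the surrounding Reshapes are
+ # redundant and the rewrite applies.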
+ input_a_shape_except_second_last_dim = input_a_shape[:-2] + [input_a_shape[-1]] + input_b_shape_except_last_dim = input_b_shape[:-1] + broadcast_matmul_output_shape = [input_a_shape[-2], input_b_shape[-1]] + for idx, (dim_from_a, dim_from_b) in enumerate( + zip( + reversed(input_a_shape_except_second_last_dim), + reversed(input_b_shape_except_last_dim), + ) + ): + if dim_from_a not in {1, dim_from_b}: + logger.info("Original shape is not broadcastable.") + return False + elif idx > 0: + broadcast_matmul_output_shape = [ + max(dim_from_a, dim_from_b), + *broadcast_matmul_output_shape, + ] + + # 2. Check if output shape is the same as the output shape from the matmul(input_a, input_b) + # Prepend the broadcast_matmul_output_shape with the longer shape of input + if dim_a > dim_b: + longer_shape = input_a_shape + shorter_shape = input_b_shape + else: + longer_shape = input_b_shape + shorter_shape = input_a_shape + broadcast_matmul_output_shape = ( + longer_shape[: -len(shorter_shape)] + broadcast_matmul_output_shape + ) + if mimic_matmul_broadcast_behavior and dim_b == 2: # noqa: PLR2004 + broadcast_matmul_output_shape = broadcast_matmul_output_shape[:-1] + if mimic_matmul_broadcast_behavior and dim_a == 2: # noqa: PLR2004 + broadcast_matmul_output_shape.pop(-2) + if shape_c != broadcast_matmul_output_shape: + logger.info( + "Final output shape is not the same. Expected %s vs actual %s", + shape_c, + broadcast_matmul_output_shape, + ) + return False + + return True + + +def two_reshapes_matmul_reshape_pattern(input_a, input_b, shape_a, shape_b, shape_c): + # TODO: Modified from `value_ints` to `value` to match pattern in benchmark models. + # This implementation misses pattern of Constants with `value_ints` attribute. + # See more at https://github.com/microsoft/onnx-rewriter/issues/191. + # A better solution is to improve pattern matching and avoid depending on writing + # Constants in pattern. See https://github.com/microsoft/onnx-rewriter/issues/192. + reshape_a = op.Reshape(input_a, shape_a) + reshape_b = op.Reshape(input_b, shape_b) + matmul = op.MatMul(reshape_a, reshape_b) + return op.Reshape(matmul, shape_c) + + +def matmul_with_two_shape_inputs(input_a, input_b, shape_a, shape_b, shape_c): + del shape_a # Unused + del shape_b # Unused + del shape_c # Unused + return op.MatMul(input_a, input_b) + + +def one_reshape_matmul_reshape_pattern(input_a, input_b, shape_a, shape_c): + reshape_a = op.Reshape(input_a, shape_a) + matmul = op.MatMul(reshape_a, input_b) + return op.Reshape(matmul, shape_c) + + +def matmul_with_one_shape_input(input_a, input_b, shape_a, shape_c): + del shape_a # Unused + del shape_c # Unused + return op.MatMul(input_a, input_b) + + +# Register the rewrite rules +two_reshapes_matmul_reshape_rule = pattern.RewriteRule( + two_reshapes_matmul_reshape_pattern, + matmul_with_two_shape_inputs, + check_if_need_reshape, +) +one_reshape_matmul_reshape_rule = pattern.RewriteRule( + one_reshape_matmul_reshape_pattern, + matmul_with_one_shape_input, + # We can use the same check_if_need_reshape function for both the rules, + # as one_reshape_matmul_reshape_pattern is a subset of two_reshapes_matmul_reshape_pattern. + check_if_need_reshape, +) + +# NOTE: The order of the rules is important. Larger pattern should be checked first. 
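+# Usage sketch (names illustrative; mirrors broadcast_to_matmul_test): build the
+# IR for a model and apply this rule set; the return value is the rewrite count.
+#   ir_model = irbuilder.build_ir(model_proto)
+#   count = rules.apply_to_model(ir_model)
+# Constant nodes feeding the removed Reshapes are left behind and are expected to
+# be cleaned up by a later unused-node removal pass.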
+rules = pattern.RewriteRuleSet( + [two_reshapes_matmul_reshape_rule, one_reshape_matmul_reshape_rule] +) diff --git a/onnxscript/onnxrewriter/rewriter/broadcast_to_matmul_test.py b/onnxscript/onnxrewriter/rewriter/broadcast_to_matmul_test.py new file mode 100644 index 0000000000..d77f33d2db --- /dev/null +++ b/onnxscript/onnxrewriter/rewriter/broadcast_to_matmul_test.py @@ -0,0 +1,283 @@ +import unittest + +import onnx.parser + +from onnxrewriter.ir import irbuilder +from onnxrewriter.rewriter import broadcast_to_matmul + + +class TwoReshapesMatMulReshapeTest(unittest.TestCase): + def test_reshape_matmul_reshape_replace_when_nd_inputs_are_broadcastable(self): + model = onnx.parser.parse_model( + """ + + agraph (float[1, 4, 512, 512] input_x, float[1, 4, 512, 64] input_y) => (float[1, 4, 512, 64] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + shape_b = Constant() + reshape_y = Reshape (input_y, shape_b) + matmul = MatMul (reshape_x, reshape_y) + shape_c = Constant() + output = Reshape (matmul, shape_c) + } + """ + ) + ir = irbuilder.build_ir(model) + count = broadcast_to_matmul.rules.apply_to_model(ir) + self.assertEqual(count, 1) + self.assertEqual(len(ir.graph.nodes), 4) + + def test_reshape_matmul_reshape_replace_when_nd_inputs_are_broadcastable_in_nested_function( + self, + ): + model = onnx.parser.parse_model( + """ + + agraph (float[1, 4, 512, 512] input_x, float[1, 4, 512, 64] input_y) => (float[1, 4, 512, 64] output) + { + output = afunction (input_x, input_y) + } + + afunction (input_x, input_y) => (output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + shape_b = Constant() + reshape_y = Reshape (input_y, shape_b) + matmul = MatMul (reshape_x, reshape_y) + shape_c = Constant() + output = Reshape (matmul, shape_c) + } + """ + ) + # Hack to put value_info in since parser does not support this experimental naming format + model.graph.value_info.append( + onnx.helper.make_tensor_value_info( + "pkg.custom::afunction/input_x", + onnx.TensorProto.FLOAT, + [1, 4, 512, 512], + ) + ) + model.graph.value_info.append( + onnx.helper.make_tensor_value_info( + "pkg.custom::afunction/input_y", onnx.TensorProto.FLOAT, [1, 4, 512, 64] + ) + ) + + ir = irbuilder.build_ir(model) + count = broadcast_to_matmul.rules.apply_to_model(ir) + self.assertEqual(count, 1) + self.assertEqual(len(ir.functions), 1) + self.assertEqual(len(ir.functions[0].nodes), 4) + self.assertEqual(ir.functions[0].nodes[-1].op_type, "MatMul") + + def test_reshape_matmul_reshape_remain_when_input_last_dim_and_second_last_dim_not_matched( + self, + ): + model = onnx.parser.parse_model( + """ + + agraph (float[512, 512, 4] input_x, float[4, 64, 512] input_y) => (float[1, 4, 512, 64] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + shape_b = Constant() + reshape_y = Reshape (input_y, shape_b) + matmul = MatMul (reshape_x, reshape_y) + shape_c = Constant() + output = Reshape (matmul, shape_c) + } + """ + ) + ir = irbuilder.build_ir(model) + count = broadcast_to_matmul.rules.apply_to_model(ir) + self.assertEqual(count, 0) + self.assertEqual(len(ir.graph.nodes), 7) + + def test_reshape_matmul_reshape_remain_when_inputs_are_not_broadcastable( + self, + ): + model = onnx.parser.parse_model( + """ + + agraph (float[2, 8, 512, 64] input_x, float[4, 4, 64, 512] input_y) => (float[2, 8, 512, 512] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + shape_b = Constant() + reshape_y = Reshape (input_y, shape_b) + matmul = MatMul 
(reshape_x, reshape_y) + shape_c = Constant() + output = Reshape (matmul, shape_c) + } + """ + ) + ir = irbuilder.build_ir(model) + count = broadcast_to_matmul.rules.apply_to_model(ir) + self.assertEqual(count, 0) + self.assertEqual(len(ir.graph.nodes), 7) + + def test_reshape_matmul_reshape_replace_when_inputs_are_broadcastable_with_one_in_dims( + self, + ): + model = onnx.parser.parse_model( + """ + + agraph (float[2, 8, 512, 64] input_x, float[1, 1, 2, 8, 64, 512] input_y) => (float[1, 1, 2, 8, 512, 512] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + shape_b = Constant() + reshape_y = Reshape (input_y, shape_b) + matmul = MatMul (reshape_x, reshape_y) + shape_c = Constant() + output = Reshape (matmul, shape_c) + } + """ + ) + ir = irbuilder.build_ir(model) + count = broadcast_to_matmul.rules.apply_to_model(ir) + self.assertEqual(count, 1) + self.assertEqual(len(ir.graph.nodes), 4) + + def test_reshape_matmul_reshape_replace_when_first_input_is_one_dimension_and_broadcastable( + self, + ): + model = onnx.parser.parse_model( + """ + + agraph (float[4] input_x, float[2, 3, 4, 5] input_y) => (float[2, 3, 5] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + shape_b = Constant() + reshape_y = Reshape (input_y, shape_b) + matmul = MatMul (reshape_x, reshape_y) + shape_c = Constant() + output = Reshape (matmul, shape_c) + } + """ + ) + ir = irbuilder.build_ir(model) + count = broadcast_to_matmul.rules.apply_to_model(ir) + self.assertEqual(count, 1) + self.assertEqual(len(ir.graph.nodes), 4) + + def test_reshape_matmul_reshape_remain_when_first_input_is_one_dimension_and_not_broadcastable( + self, + ): + model = onnx.parser.parse_model( + """ + + agraph (float[8] input_x, float[2, 3, 4, 5] input_y) => (float[2, 3, 2, 5] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + shape_b = Constant() + reshape_y = Reshape (input_y, shape_b) + matmul = MatMul (reshape_x, reshape_y) + shape_c = Constant() + output = Reshape (matmul, shape_c) + } + """ + ) + ir = irbuilder.build_ir(model) + count = broadcast_to_matmul.rules.apply_to_model(ir) + self.assertEqual(count, 0) + self.assertEqual(len(ir.graph.nodes), 7) + + def test_reshape_matmul_reshape_replace_when_second_input_is_one_dimension_and_broadcastable( + self, + ): + model = onnx.parser.parse_model( + """ + + agraph (float[2, 3, 4, 5] input_x, float[5] input_y) => (float[2, 3, 4] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + shape_b = Constant() + reshape_y = Reshape (input_y, shape_b) + matmul = MatMul (reshape_x, reshape_y) + shape_c = Constant() + output = Reshape (matmul, shape_c) + } + """ + ) + ir = irbuilder.build_ir(model) + count = broadcast_to_matmul.rules.apply_to_model(ir) + self.assertEqual(count, 1) + self.assertEqual(len(ir.graph.nodes), 4) + + def test_reshape_matmul_reshape_remain_when_second_input_is_one_dimension_and_not_broadcastable( + self, + ): + model = onnx.parser.parse_model( + """ + + agraph (float[2, 3, 4, 5] input_x, float[10] input_y) => (float[2, 3, 4, 2] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + shape_b = Constant() + reshape_y = Reshape (input_y, shape_b) + matmul = MatMul (reshape_x, reshape_y) + shape_c = Constant() + output = Reshape (matmul, shape_c) + } + """ + ) + ir = irbuilder.build_ir(model) + count = broadcast_to_matmul.rules.apply_to_model(ir) + self.assertEqual(count, 0) + self.assertEqual(len(ir.graph.nodes), 7) + + def 
test_reshape_matmul_reshape_remain_when_output_is_not_matmul_broadcasted( + self, + ): + model = onnx.parser.parse_model( + """ + + agraph (float[2, 3, 4, 5] input_x, float[5, 8] input_y) => (float[2, 4, 6, 4] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + shape_b = Constant() + reshape_y = Reshape (input_y, shape_b) + matmul = MatMul (reshape_x, reshape_y) + shape_c = Constant() + output = Reshape (matmul, shape_c) + } + """ + ) + ir = irbuilder.build_ir(model) + count = broadcast_to_matmul.rules.apply_to_model(ir) + self.assertEqual(count, 0) + self.assertEqual(len(ir.graph.nodes), 7) + + +class OneReshapeMatMulReshapeTest(unittest.TestCase): + def test_reshape_matmul_reshape_replace_when_nd_inputs_are_broadcastable(self): + model = onnx.parser.parse_model( + """ + + agraph (float[1, 512, 4096] input_x, float[4096, 4096] input_y) => (float[1, 512, 4096] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + matmul = MatMul (reshape_x, input_y) + shape_c = Constant() + output = Reshape (matmul, shape_c) + } + """ + ) + ir = irbuilder.build_ir(model) + count = broadcast_to_matmul.rules.apply_to_model(ir) + self.assertEqual(count, 1) + # The constant nodes are not removed. They should be removed by a subsequent DCE in optimizer. + self.assertEqual(len(ir.graph.nodes), 3) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/onnxrewriter/rewriter/cast_constant_of_shape.py b/onnxscript/onnxrewriter/rewriter/cast_constant_of_shape.py new file mode 100644 index 0000000000..a4aa054dd8 --- /dev/null +++ b/onnxscript/onnxrewriter/rewriter/cast_constant_of_shape.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +import logging +from typing import Any, Sequence + +import numpy as np +import onnx + +from onnxrewriter import ir +from onnxrewriter.rewriter import pattern + +op = pattern.onnxop +logger = logging.getLogger(__name__) + + +def cast_constant_of_shape( + shape: Sequence[int], + t: Any, + dtype: int, + match_bindings: dict[str, ir.Value | Any] | None = None, # noqa: ARG001 +) -> pattern.OpPattern: + constant = op.ConstantOfShape(shape, value=t) + return op.Cast(constant, to=dtype) + + +def fused_cast_constant_of_shape( + shape: Sequence[int], t: Any, dtype: int, match_bindings: dict[str, ir.Value | Any] +) -> pattern.OpPattern: + del dtype # unused + del t # unused + v_dtype = match_bindings["dtype"] + v_t = match_bindings["t"] + casted_val = onnx.numpy_helper.to_array(v_t).astype( # type: ignore[arg-type] + dtype=onnx.helper.tensor_dtype_to_np_dtype(v_dtype) # type: ignore[arg-type] + ) + return op.ConstantOfShape(shape, value=casted_val) + + +def cast_constant_of_shape_without_value( + shape: Sequence[int], + dtype: int, + match_bindings: dict[str, ir.Value | Any] | None = None, +) -> pattern.OpPattern: + del match_bindings # Unused + constant = op.ConstantOfShape(shape) + return op.Cast(constant, to=dtype) + + +def fused_cast_constant_of_shape_without_value( + shape: Sequence[int], dtype: int, match_bindings: dict[str, ir.Value | Any] +) -> pattern.OpPattern: + del dtype # Unused + v_dtype = match_bindings["dtype"] + val = np.zeros(1, dtype=onnx.helper.tensor_dtype_to_np_dtype(v_dtype)) # type: ignore + return op.ConstantOfShape(shape, value=val) + + +cast_constant_of_shape_rule = pattern.RewriteRule( + cast_constant_of_shape, + pattern.ReplacementPatternFunction(fused_cast_constant_of_shape, delay_run=True), +) + +cast_constant_of_shape_without_value_rule = pattern.RewriteRule( + 
cast_constant_of_shape_without_value, + pattern.ReplacementPatternFunction( + fused_cast_constant_of_shape_without_value, delay_run=True + ), +) + +rules = pattern.RewriteRuleSet( + [ + cast_constant_of_shape_rule, + cast_constant_of_shape_without_value_rule, + ] +) diff --git a/onnxscript/onnxrewriter/rewriter/cast_constant_of_shape_test.py b/onnxscript/onnxrewriter/rewriter/cast_constant_of_shape_test.py new file mode 100644 index 0000000000..d0cfd9a005 --- /dev/null +++ b/onnxscript/onnxrewriter/rewriter/cast_constant_of_shape_test.py @@ -0,0 +1,46 @@ +import unittest + +import onnx.parser + +from onnxrewriter.ir import irbuilder +from onnxrewriter.rewriter import cast_constant_of_shape + + +class CastConstantOfShapeTest(unittest.TestCase): + def test_cast_after_constant_of_shape_is_fused(self): + model = onnx.parser.parse_model( + """ + + agraph (int64[2] input_x) => (float16[1, 4] output) + { + constant = ConstantOfShape (input_x) + output = Cast (constant) + } + """ + ) + ir = irbuilder.build_ir(model) + count = cast_constant_of_shape.rules.apply_to_model(ir) + self.assertEqual(count, 1) + self.assertEqual(len(ir.graph.nodes), 1) + self.assertEqual(ir.graph.nodes[0].attributes["value"].data_type, 10) + + def test_cast_after_constant_of_shape_without_value_is_fused(self): + model = onnx.parser.parse_model( + """ + + agraph (int64[2] input_x) => (float16[1, 4] output) + { + constant = ConstantOfShape (input_x) + output = Cast (constant) + } + """ + ) + ir = irbuilder.build_ir(model) + count = cast_constant_of_shape.rules.apply_to_model(ir) + self.assertEqual(count, 1) + self.assertEqual(len(ir.graph.nodes), 1) + self.assertEqual(ir.graph.nodes[0].attributes["value"].data_type, 10) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/onnxrewriter/rewriter/erfgelu.py b/onnxscript/onnxrewriter/rewriter/erfgelu.py new file mode 100644 index 0000000000..9a8ad04dbf --- /dev/null +++ b/onnxscript/onnxrewriter/rewriter/erfgelu.py @@ -0,0 +1,30 @@ +import math + +from onnxrewriter.rewriter import pattern + +op = pattern.onnxop + + +# Pattern to match against +def erf_gelu_pattern(x): + # erf_gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))) + # half = pattern.Constant(0.5) + # sqrt2 = pattern.Constant(1.4142) + # x_div_sqrt2 = op.Div(x, sqrt2) + # erf = op.Erf(x_div_sqrt2) + # one = pattern.Constant(1.0) + # one_plus_erf = op.Add(erf, one) + # x_mul_one_plus_erf = op.Mul(x, one_plus_erf) + # return op.Mul(half, x_mul_one_plus_erf) + return 0.5 * (x * (op.Erf(x / math.sqrt(2)) + 1.0)) + + +msft_op = pattern.msft_op + + +# Replacement +def gelu(x): + return msft_op.Gelu(x) + + +rule = pattern.RewriteRule(erf_gelu_pattern, gelu) diff --git a/onnxscript/onnxrewriter/rewriter/function_rule.py b/onnxscript/onnxrewriter/rewriter/function_rule.py new file mode 100644 index 0000000000..065acc609d --- /dev/null +++ b/onnxscript/onnxrewriter/rewriter/function_rule.py @@ -0,0 +1,248 @@ +from __future__ import annotations + +import functools +import logging + +import onnx +import onnxscript +from packaging import version + +from onnxrewriter import ir +from onnxrewriter.ir import visitor +from onnxrewriter.rewriter import pattern + +logger = logging.getLogger(__name__) + + +class FunctionRewriteError(RuntimeError): ... + + +@functools.lru_cache +def parse_domain(function_domain: str) -> tuple[str, version.Version | None]: + splits = function_domain.split(".") + if splits[0] != "pkg": + raise FunctionRewriteError( + f"Invalid domain: {function_domain}. Must start with 'pkg'." 
+ ) + splits = splits[1:] + for i, s in enumerate(splits): + if s.isdigit(): + return ".".join(splits[:i]), version.parse(".".join(splits[i:])) + return ".".join(splits), None + + +MIN_VERSION = version.parse("0") +MAX_VERSION = version.parse("9999") + + +class VersionController: + def __init__(self): + # A dispatch table for rewrite implementation based on the function package version. + self.dispatch_table: dict[ + tuple[version.Version, version.Version], callable + ] = {} + + def register_version( + self, + min_version: version.Version | str | None = None, + max_version: version.Version | str | None = None, + ): + """Register a function implementation for a specific package version range [min_version, max_version). + + Args: + min_version: The minimum version of the package. Inclusive. + max_version: The maximum version of the package. Exclusive. + """ + # TODO: check for version overloap + + min_version = MIN_VERSION if min_version is None else min_version + max_version = MAX_VERSION if max_version is None else max_version + if isinstance(min_version, str): + min_version = version.parse(min_version) + if isinstance(max_version, str): + max_version = version.parse(max_version) + + def deco(func): + self.dispatch_table[(min_version, max_version)] = func + return func + + return deco + + def dispatch(self, version: version.Version | None) -> callable | None: + if version is None: + if len(self.dispatch_table) == 1: + return next(iter(self.dispatch_table.values())) + raise ValueError( + "No function package version specified, however there are multiple " + f"fusion rules based on package version: {self.dispatch_table.keys()}." + ) + for (min_version, max_version), func in self.dispatch_table.items(): + greater_than_min = min_version is None or min_version <= version + less_than_max = max_version is None or version < max_version + if greater_than_min and less_than_max: + return func + return None + + +class FunctionRewriteRule(pattern.RewriteRule): + FUNCTION_KEYWORD: str | tuple[str] + """The keyword to match the function name. If a tuple, any keyword will match.""" + + PACKAGE_NAME: str + """The package name to match. + + For example, 'transformers' to match for domain name 'pkg.transformers.4.36.2'. + """ + + _opset_imports: dict[str, int] + onnx_opset: onnxscript.values.Opset + _function_shape_env: visitor.FunctionShapeEnv + + def __init__(self, opset: onnxscript.values.Opset = onnxscript.opset18) -> None: + self.onnx_opset = opset + + def _match_function(self, function: onnx.FunctionProto, pkg_name: str) -> bool: + # TODO: Consolidate more checks from `compose_new_function` to here. + if pkg_name != self.PACKAGE_NAME: + logger.info( + "Rule %s did not match function %s::%s. 
Package name mismatch '%s' != '%s'.", + self.__class__.__name__, + function.domain, + function.name, + self.PACKAGE_NAME, + pkg_name, + ) + return False + + if isinstance(self.FUNCTION_KEYWORD, str): + return function.name.find(self.FUNCTION_KEYWORD) != -1 + elif isinstance(self.FUNCTION_KEYWORD, tuple): + return any( + function.name.find(keyword) != -1 for keyword in self.FUNCTION_KEYWORD + ) + else: + raise ValueError( # noqa: TRY004 + f"Function keyword must be str or tuple, got {self.FUNCTION_KEYWORD}" + ) + + def _find_node_contains_key_in_name( + self, function: onnx.FunctionProto, keyword: str + ) -> onnx.NodeProto | None: + for node in function.node: + if node.name.find(keyword) != -1: + return node + return None + + def _find_node_by_type( + self, function: onnx.FunctionProto, domain: str, op_type: str + ) -> onnx.NodeProto | None: + # Repeat + for node in function.node: + if node.domain == domain and node.op_type == op_type: + return node + return None + + def _find_constant_node( + self, function: onnx.FunctionProto, value_name: str + ) -> onnx.NodeProto | None: + # Potentially repeat, utility function. + for node in function.node: + for output in node.output: + if output == value_name: + return node + return None + + def compose_new_function( + self, old_function: onnx.FunctionProto, pkg_version: version.Version | None + ) -> tuple[onnx.FunctionProto, tuple[onnx.OperatorSetIdProto]]: + """Compose a new function from the old function. + + Returns: + A tuple of the new function and the opset imports. + + Raises: + FunctionRewriteError: If the rewrite fails. + """ + func = self._version_controller.dispatch(pkg_version) + if func is not None: + return func(self, old_function) + raise FunctionRewriteError( + f"No rewrite implementation for package version {pkg_version}." + ) + + def try_rewrite_function( + self, function: onnx.FunctionProto, model: onnx.ModelProto + ) -> bool: + try: + pkg_name, pkg_version = parse_domain(function.domain) + except FunctionRewriteError as e: + logger.warning("Could not parse domain: %s", e) + return False + + if pkg_version is None and not pkg_name.startswith("onnxscript"): + logger.warning( + "Could not parse version for domain of function %s::%s. " + "Usually this implies the model source is not from a package, but from arbitrary python files instead. " + "For example, models not defined in huggingface/transformers but loaded via 'trust_remote_code=True'.", + function.domain, + function.name, + ) + + if not self._match_function(function, pkg_name): + return False + logger.info( + "Rule %s matched function %s::%s", + self.__class__.__name__, + function.domain, + function.name, + ) + + try: + new_function, opset_imports = self.compose_new_function( + function, pkg_version + ) + except FunctionRewriteError as e: + logger.warning("Could not rewrite function: %s", e) + return False + + nodes = new_function.node + + del function.input[:] + function.input.extend(new_function.input) + del function.output[:] + function.output.extend(new_function.output) + + del function.node[:] + function.node.extend(nodes) + for new_opset in opset_imports: + function.opset_import.append(new_opset) + if new_opset.domain not in self._opset_imports: + model.opset_import.append(new_opset) + + return True + + def try_rewrite(self, model: ir.Model, value) -> bool: + raise NotImplementedError( + "Use `try_rewrite_function` instead for function based rewrites." 
+ ) + + def lookup(self, function: onnx.FunctionProto, value_name: str) -> ir.Value | None: + return self._function_shape_env.lookup(function, value_name) + + def apply_to_model(self, model: ir.Model, *, commute: bool = False) -> int: + del commute # unused + model_proto: onnx.ModelProto = model.original_model_proto + self._function_shape_env = visitor.FunctionShapeEnv() + self._function_shape_env.load_from_model_proto(model.original_model_proto) + self._opset_imports = {x.domain: x.version for x in model_proto.opset_import} + + rewrite_count = 0 + for function in model_proto.functions: + rewrite_count += self.try_rewrite_function(function, model_proto) + return rewrite_count + + def count_matches(self, model, *, commute: bool = False) -> int: + raise NotImplementedError() + + def commute(self) -> list[pattern.RewriteRule]: + raise NotImplementedError() diff --git a/onnxscript/onnxrewriter/rewriter/gemm_to_matmul_add.py b/onnxscript/onnxrewriter/rewriter/gemm_to_matmul_add.py new file mode 100644 index 0000000000..62521566a8 --- /dev/null +++ b/onnxscript/onnxrewriter/rewriter/gemm_to_matmul_add.py @@ -0,0 +1,23 @@ +from onnxrewriter.rewriter import pattern +from onnxrewriter.rewriter.broadcast_to_matmul import check_if_need_reshape + +op = pattern.onnxop + + +# Pattern to match against +def reshape_gemm_reshape_pattern(input_a, input_b, input_c, shape_a, shape_c): + reshape_a = op.Reshape(input_a, shape_a) + # TODO: Temporary workaround to support benchmodels. + # Tracked by https://github.com/microsoft/onnx-rewriter/issues/197. + gemm = op.Gemm(reshape_a, input_b, input_c, alpha=1.0, beta=1.0) + return op.Reshape(gemm, shape_c) + + +def matmul_add(input_a, input_b, input_c, shape_a, shape_d): # noqa: ARG001 + matmul = op.MatMul(input_a, input_b) + return op.Add(matmul, input_c) + + +rule = pattern.RewriteRule( + reshape_gemm_reshape_pattern, matmul_add, check_if_need_reshape +) diff --git a/onnxscript/onnxrewriter/rewriter/gemm_to_matmul_add_test.py b/onnxscript/onnxrewriter/rewriter/gemm_to_matmul_add_test.py new file mode 100644 index 0000000000..8a4d194f4f --- /dev/null +++ b/onnxscript/onnxrewriter/rewriter/gemm_to_matmul_add_test.py @@ -0,0 +1,254 @@ +import unittest + +import onnx.parser + +from onnxrewriter.ir import irbuilder +from onnxrewriter.rewriter import gemm_to_matmul_add + + +class ReshapeGemmReshapeTest(unittest.TestCase): + def test_reshape_gemm_reshape_replace_when_nd_inputs_are_broadcastable(self): + model = onnx.parser.parse_model( + """ + + agraph (float[1, 4, 512, 512] input_x, float[4, 512, 64] input_y, float[4, 512, 64] input_z) => (float[1, 4, 512, 64] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + gemm = Gemm (reshape_x, input_y, input_z) + shape_d = Constant() + output = Reshape (gemm, shape_d) + } + """ + ) + + ir = irbuilder.build_ir(model) + count = gemm_to_matmul_add.rule.apply_to_model(ir) + self.assertEqual(count, 1) + self.assertEqual(len(ir.graph.nodes), 4) + + def test_reshape_gemm_reshape_replace_when_nd_inputs_are_broadcastable_in_nested_function( + self, + ): + model = onnx.parser.parse_model( + """ + + agraph (float[1, 4, 512, 512] input_x, float[4, 512, 64] input_y, float[4, 512, 64] input_z) => (float[1, 4, 512, 64] output) + { + output = afunction (input_x, input_y, input_z) + } + + afunction (input_x, input_y, input_z) => (output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + gemm = Gemm (reshape_x, input_y, input_z) + shape_d = Constant() + output = Reshape (gemm, 
shape_d) + } + """ + ) + # Hack to put value_info in since parser does not support this experimental naming format + model.graph.value_info.append( + onnx.helper.make_tensor_value_info( + "pkg.custom::afunction/input_x", + onnx.TensorProto.FLOAT, + [1, 4, 512, 512], + ) + ) + model.graph.value_info.append( + onnx.helper.make_tensor_value_info( + "pkg.custom::afunction/input_y", onnx.TensorProto.FLOAT, [4, 512, 64] + ) + ) + model.graph.value_info.append( + onnx.helper.make_tensor_value_info( + "pkg.custom::afunction/input_z", onnx.TensorProto.FLOAT, [1, 4, 512, 64] + ) + ) + + ir = irbuilder.build_ir(model) + count = gemm_to_matmul_add.rule.apply_to_model(ir) + self.assertEqual(count, 1) + self.assertEqual(len(ir.functions), 1) + self.assertEqual(len(ir.functions[0].nodes), 4) + self.assertEqual(ir.functions[0].nodes[2].op_type, "MatMul") + self.assertEqual(ir.functions[0].nodes[3].op_type, "Add") + + def test_reshape_gemm_reshape_remain_when_input_last_dim_and_second_last_dim_not_matched( + self, + ): + model = onnx.parser.parse_model( + """ + + agraph (float[1, 4, 512, 512] input_x, float[4, 256, 64] input_y, float[4, 512, 64] input_z) => (float[1, 4, 512, 64] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + gemm = Gemm (reshape_x, input_y, input_z) + shape_d = Constant() + output = Reshape (gemm, shape_d) + } + """ + ) + ir = irbuilder.build_ir(model) + count = gemm_to_matmul_add.rule.apply_to_model(ir) + self.assertEqual(count, 0) + self.assertEqual(len(ir.graph.nodes), 5) + + def test_reshape_gemm_reshape_remain_when_inputs_are_not_broadcastable( + self, + ): + model = onnx.parser.parse_model( + """ + + agraph (float[2, 2, 512, 512] input_x, float[4, 512, 64] input_y, float[4, 512, 64] input_z) => (float[1, 4, 512, 64] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + gemm = Gemm (reshape_x, input_y, input_z) + shape_d = Constant() + output = Reshape (gemm, shape_d) + } + """ + ) + ir = irbuilder.build_ir(model) + count = gemm_to_matmul_add.rule.apply_to_model(ir) + self.assertEqual(count, 0) + self.assertEqual(len(ir.graph.nodes), 5) + + def test_reshape_gemm_reshape_replace_when_inputs_are_broadcastable_with_one_in_dims( + self, + ): + model = onnx.parser.parse_model( + """ + + agraph (float[4, 512, 512] input_x, float[1, 4, 512, 64] input_y, float[1, 4, 512, 64] input_z) => (float[1, 4, 512, 64] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + gemm = Gemm (reshape_x, input_y, input_z) + shape_d = Constant() + output = Reshape (gemm, shape_d) + } + """ + ) + ir = irbuilder.build_ir(model) + count = gemm_to_matmul_add.rule.apply_to_model(ir) + self.assertEqual(count, 1) + self.assertEqual(len(ir.graph.nodes), 4) + self.assertEqual(ir.graph.nodes[2].op_type, "MatMul") + self.assertEqual(ir.graph.nodes[3].op_type, "Add") + + def test_reshape_gemm_reshape_replace_when_first_input_is_one_dimension_and_broadcastable( + self, + ): + model = onnx.parser.parse_model( + """ + + agraph (float[4] input_x, float[2, 3, 4, 5] input_y, float[2, 3, 5] input_z) => (float[2, 3, 5] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + gemm = Gemm (reshape_x, input_y, input_z) + shape_d = Constant() + output = Reshape (gemm, shape_d) + } + """ + ) + ir = irbuilder.build_ir(model) + count = gemm_to_matmul_add.rule.apply_to_model(ir) + self.assertEqual(count, 1) + self.assertEqual(len(ir.graph.nodes), 4) + self.assertEqual(ir.graph.nodes[2].op_type, "MatMul") + 
self.assertEqual(ir.graph.nodes[3].op_type, "Add") + + def test_reshape_gemm_reshape_replace_when_first_input_is_one_dimension_and_not_broadcastable( + self, + ): + model = onnx.parser.parse_model( + """ + + agraph (float[8] input_x, float[2, 3, 4, 5] input_y, float[2, 3, 5] input_z) => (float[2, 3, 5] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + gemm = Gemm (reshape_x, input_y, input_z) + shape_d = Constant() + output = Reshape (gemm, shape_d) + } + """ + ) + ir = irbuilder.build_ir(model) + count = gemm_to_matmul_add.rule.apply_to_model(ir) + self.assertEqual(count, 0) + self.assertEqual(len(ir.graph.nodes), 5) + + def test_reshape_gemm_reshape_replace_when_second_input_is_one_dimension_and_broadcastable( + self, + ): + model = onnx.parser.parse_model( + """ + + agraph (float[2, 3, 5, 4] input_x, float[4] input_y, float[2, 3, 5] input_z) => (float[2, 3, 5] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + gemm = Gemm (reshape_x, input_y, input_z) + shape_d = Constant() + output = Reshape (gemm, shape_d) + } + """ + ) + ir = irbuilder.build_ir(model) + count = gemm_to_matmul_add.rule.apply_to_model(ir) + self.assertEqual(count, 1) + self.assertEqual(len(ir.graph.nodes), 4) + self.assertEqual(ir.graph.nodes[2].op_type, "MatMul") + self.assertEqual(ir.graph.nodes[3].op_type, "Add") + + def test_reshape_gemm_reshape_replace_when_second_input_is_one_dimension_and_not_broadcastable( + self, + ): + model = onnx.parser.parse_model( + """ + + agraph (float[2, 3, 5, 4] input_x, float[10] input_y, float[2, 3, 5] input_z) => (float[2, 3, 5] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + gemm = Gemm (reshape_x, input_y, input_z) + shape_d = Constant() + output = Reshape (gemm, shape_d) + } + """ + ) + ir = irbuilder.build_ir(model) + count = gemm_to_matmul_add.rule.apply_to_model(ir) + self.assertEqual(count, 0) + self.assertEqual(len(ir.graph.nodes), 5) + + def test_reshape_gemm_reshape_remain_when_output_is_not_matmul_broadcasted( + self, + ): + model = onnx.parser.parse_model( + """ + + agraph (float[2, 3, 5, 4] input_x, float[5] input_y, float[2, 3, 5] input_z) => (float[2, 4, 6] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + gemm = Gemm (reshape_x, input_y, input_z) + shape_d = Constant() + output = Reshape (gemm, shape_d) + } + """ + ) + ir = irbuilder.build_ir(model) + count = gemm_to_matmul_add.rule.apply_to_model(ir) + self.assertEqual(count, 0) + self.assertEqual(len(ir.graph.nodes), 5) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/onnxrewriter/rewriter/no_op.py b/onnxscript/onnxrewriter/rewriter/no_op.py new file mode 100644 index 0000000000..c1e6cf08fd --- /dev/null +++ b/onnxscript/onnxrewriter/rewriter/no_op.py @@ -0,0 +1,44 @@ +from onnxrewriter.rewriter import pattern + +op = pattern.onnxop + +# TODO: Support 1-D constant tensors +# https://github.com/microsoft/onnx-rewriter/issues/186 + + +# Pattern to match against +def mul_by_1(x): + return x * 1 + + +def add_0(x): + return x + 0 + + +def sub_0(x): + return x - 0 + + +def div_by_1(x): + return x / 1 + + +# Replacement +def identity(x): + return op.Identity(x) + + +mul_by_1_rule = pattern.RewriteRule(mul_by_1, identity) +add_0_rule = pattern.RewriteRule(add_0, identity) +sub_0_rule = pattern.RewriteRule(sub_0, identity) +div_by_1_rule = pattern.RewriteRule(div_by_1, identity) +# TODO: Include Mul by 0, 0 by Mul, 0 by Div? 
Those would be 0s, but not no-ops + +rules = pattern.RewriteRuleSet( + [ + *mul_by_1_rule.commute(), + *add_0_rule.commute(), + sub_0_rule, + div_by_1_rule, + ] +) diff --git a/onnxscript/onnxrewriter/rewriter/no_op_test.py b/onnxscript/onnxrewriter/rewriter/no_op_test.py new file mode 100644 index 0000000000..5bd2fbb98b --- /dev/null +++ b/onnxscript/onnxrewriter/rewriter/no_op_test.py @@ -0,0 +1,180 @@ +import unittest + +import onnx.parser +import parameterized + +from onnxrewriter.ir import irbuilder +from onnxrewriter.rewriter import no_op + + +class NoOpTest(unittest.TestCase): + def _check(self, model_text: str) -> None: + model = onnx.parser.parse_model(model_text) + ir = irbuilder.build_ir(model) + count = no_op.rules.apply_to_model(ir) + self.assertEqual(count, 1) + self.assertEqual(ir.graph.nodes[-1].op_type, "Identity") + + @parameterized.parameterized.expand( + [ + ("float one input", "float[M]", "value_float=1.0", "one, input"), + ("int one input", "int32[M]", "value_int=1", "one, input"), + ("float input one", "float[M]", "value_float=1.0", "input, one"), + ("int input one", "int32[M]", "value_int=1", "input, one"), + ] + ) + def test_mul_one_should_become_no_op(self, _, dtype, constant_value, input_order): + self._check( + f""" + + agraph ({dtype} input) => ({dtype} output) + {{ + one = Constant<{constant_value}>() + output = Mul({input_order}) + }} + """ + ) + + @parameterized.parameterized.expand( + [ + ("float one input", "float[M]", "float one = {1.0}", "one, input"), + ("int one input", "int32[M]", "int32 one = {1}", "one, input"), + ("float input one", "float[M]", "float one = {1.0}", "input, one"), + ("int input one", "int32[M]", "int32 one = {1}", "input, one"), + ] + ) + def test_mul_one_should_become_no_op_initializer( + self, _, dtype, constant_value, input_order + ): + self._check( + f""" + + agraph ({dtype} input) => ({dtype} output) + <{constant_value}> + {{ + output = Mul({input_order}) + }} + """ + ) + + @parameterized.parameterized.expand( + [ + ("float zero input", "float[M]", "value_float=0.0", "zero, input"), + ("int zero input", "int32[M]", "value_int=0", "zero, input"), + ("float input zero", "float[M]", "value_float=0.0", "input, zero"), + ("int input zero", "int32[M]", "value_int=0", "input, zero"), + ] + ) + def test_add_zero_should_become_no_op(self, _, dtype, constant_value, input_order): + self._check( + f""" + + agraph ({dtype} input) => ({dtype} output) + {{ + zero = Constant<{constant_value}>() + output = Add({input_order}) + }} + """ + ) + + @parameterized.parameterized.expand( + [ + ("float input zero", "float[M]", "float zero = {0.0}", "input, zero"), + ("int input zero", "int32[M]", "int32 zero = {0}", "input, zero"), + ("float input zero", "float[M]", "float zero = {0.0}", "input, zero"), + ("int input zero", "int32[M]", "int32 zero = {0}", "input, zero"), + ] + ) + def test_add_zero_should_become_no_op_initializer( + self, _, dtype, constant_value, input_order + ): + self._check( + f""" + + agraph ({dtype} input) => ({dtype} output) + <{constant_value}> + {{ + output = Add({input_order}) + }} + """ + ) + + @parameterized.parameterized.expand( + [ + ("float input zero", "float[M]", "value_float=0.0", "input, zero"), + ("int input zero", "int32[M]", "value_int=0", "input, zero"), + ] + ) + def test_sub_zero_should_become_no_op(self, _, dtype, constant_value, input_order): + self._check( + f""" + + agraph ({dtype} input) => ({dtype} output) + {{ + zero = Constant<{constant_value}>() + output = Sub({input_order}) + }} + """ + ) + + 
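+    # Editorial note: no_op.py registers the Mul/Add rules via `.commute()`, so a
+    # single pattern covers both operand orders (`x * 1` and `1 * x`, `x + 0` and
+    # `0 + x`), which is why the Mul/Add cases above parameterize both orders. Sub
+    # and Div are not commutative, so the cases below only exercise the
+    # `input, zero` / `input, one` orderings.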
@parameterized.parameterized.expand( + [ + ("float input zero", "float[M]", "float zero = {0.0}", "input, zero"), + ("int input zero", "int32[M]", "int32 zero = {0}", "input, zero"), + ] + ) + def test_sub_zero_should_become_no_op_initializer( + self, _, dtype, constant_value, input_order + ): + self._check( + f""" + + agraph ({dtype} input) => ({dtype} output) + <{constant_value}> + {{ + output = Sub({input_order}) + }} + """ + ) + + @parameterized.parameterized.expand( + [ + ("float input one", "float[M]", "value_float=1.0", "input, one"), + ("int input one", "int32[M]", "value_int=1", "input, one"), + ] + ) + def test_div_one_should_become_no_op(self, _, dtype, constant_value, input_order): + self._check( + f""" + + agraph ({dtype} input) => ({dtype} output) + {{ + one = Constant<{constant_value}>() + output = Div({input_order}) + }} + """ + ) + + @parameterized.parameterized.expand( + [ + ("float input one", "float[M]", "float one = {1.0}", "input, one"), + ("int input one", "int32[M]", "int32 one = {1}", "input, one"), + ] + ) + def test_div_one_should_become_no_op_with_initializer( + self, _, dtype, constant_value, input_order + ): + self._check( + f""" + + agraph ({dtype} input) => ({dtype} output) + <{constant_value}> + {{ + output = Div({input_order}) + }} + """ + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/onnxrewriter/rewriter/onnxruntime/__init__.py b/onnxscript/onnxrewriter/rewriter/onnxruntime/__init__.py new file mode 100644 index 0000000000..c3d2e2aaab --- /dev/null +++ b/onnxscript/onnxrewriter/rewriter/onnxruntime/__init__.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +import onnx + +from onnxrewriter.ir import irbuilder, protobuilder +from onnxrewriter.optimizer import remove_unused +from onnxrewriter.rewriter import function_rule, pattern +from onnxrewriter.rewriter.onnxruntime import ( + instance_to_group_normalization, + softmax, + transformers, +) + +ORT_FUNCTION_REWRITE_RULES = [*transformers.TRANSFORMERS_FUNCTION_REWRITE_RULES] + +ORT_PATTERN_REWRITE_RULES = [ + *softmax.rules.rules, + *instance_to_group_normalization.rules.rules, +] + + +def rewrite( + model: onnx.ModelProto, + function_rules: list[type[function_rule.FunctionRewriteRule]] | None = None, + pattern_rules: list[pattern.RewriteRule] | None = None, +) -> onnx.ModelProto: + """Rewrite the model using the given rules. + + Args: + model: The model to rewrite. + function_rules: The function rewrite rules to apply. If None, the default rules + for onnxruntime are used. + pattern_rules: The pattern rewrite rules to apply. If None, the default rules + for onnxruntime are used. + + Returns: + The rewritten model. + """ + function_rules = function_rules or ORT_FUNCTION_REWRITE_RULES + pattern_rules = pattern_rules or ORT_PATTERN_REWRITE_RULES + # TODO: Function rules first, or pattern rules first? 
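+    # Illustrative usage of this entry point (editorial sketch; the file names are
+    # placeholders):
+    #
+    #   import onnx
+    #   from onnxrewriter.rewriter.onnxruntime import rewrite
+    #
+    #   model_proto = onnx.load("model.onnx")
+    #   model_proto = rewrite(model_proto)  # function rules first, then pattern rules
+    #   onnx.save(model_proto, "model.rewritten.onnx")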
+ if function_rules: + model_ir = irbuilder.build_ir(model) + for rule_cls in function_rules: + rule_cls().apply_to_model(model_ir) + model = model_ir.original_model_proto + if pattern_rules: + model_ir = irbuilder.build_ir(model) + count = pattern.RewriteRuleSet(pattern_rules).apply_to_model(model_ir) + print(f"Applied {count} pattern rewrite rules.") + model = protobuilder.build_model_proto(model_ir) + remove_unused.remove_unused_nodes(model) + return model diff --git a/onnxscript/onnxrewriter/rewriter/onnxruntime/instance_to_group_normalization.py b/onnxscript/onnxrewriter/rewriter/onnxruntime/instance_to_group_normalization.py new file mode 100644 index 0000000000..fbc5411d91 --- /dev/null +++ b/onnxscript/onnxrewriter/rewriter/onnxruntime/instance_to_group_normalization.py @@ -0,0 +1,152 @@ +from __future__ import annotations + +import logging +from typing import Any + +import numpy as np +import onnx + +from onnxrewriter import ir +from onnxrewriter.rewriter import pattern + +op = pattern.onnxop +msft_op = pattern.msft_op + +logger = logging.getLogger(__name__) + + +def _check_if_simulated_instance_norm_is_used_impl( + input_x, + adjusted_input_shape, + original_input_shape, + weight_for_norm, + bias_for_norm, + weight_full, + bias_full, + **kwargs, # noqa: ARG001 +) -> bool: + if not np.all(weight_for_norm.value_as_np_array == 1): + return False + if not np.all(bias_for_norm.value_as_np_array == 0): + return False + + input_rank_minus_one = len(input_x.shape) - 1 + weight_full_rank = len(weight_full.shape) + bias_full_rank = len(bias_full.shape) + if ( + weight_full_rank != input_rank_minus_one + or bias_full_rank != input_rank_minus_one + ): + return False + + input_rank = len(input_x.shape) + if input_rank != 4: # noqa: PLR2004 + return False + + weight_full_shape = weight_full.shape + if not all(dim == 1 for dim in weight_full_shape[1:]): + return False + bias_full_shape = bias_full.shape + if not all(dim == 1 for dim in bias_full_shape[1:]): + return False + + adjusted_input_shape = adjusted_input_shape.value_as_np_array + g = weight_for_norm.shape[0] + if adjusted_input_shape is None or adjusted_input_shape.tolist() != [0, g, -1]: + return False + + # NOTE: Restrict the rule to only support constant shape + original_input_shape = original_input_shape.value_as_np_array + if original_input_shape is None or original_input_shape.tolist() != input_x.shape: + return False + + return True + + +def check_if_simulated_instance_norm_is_used( + match_bindings: dict[str, ir.Value | Any], +) -> bool: + """Check if the simulated instance normalization is used. + + In torchlib with opset18, onnx.GroupNorm is using wrong definition, so + we use InstanceNormalization to simulate GroupNormalization. We need to check if there are arguments created to simulation. + If there are, then we need to replace the pattern. If they are not used, then we don't need to replace the pattern. + + To validate this, we need to check the following: + 1. weight_for_norm are all 1 and bias_for_norm are all 0, as they are created for the simulation. + 2. weight_full and bias_full are unsqueezed to be easily broadcastable. + 3. input rank should be 4 + 4. weight_full and bias_full should have ones except first dim. + 5. adjusted_input_shape is a constant tensor of form [0, g, -1] + 6. original_input_shape is the same as input_x shape. + + Args: + match_bindings: The match binding dictionary from a MatchResult. + + Returns: + bool: True if the simulated instance normalization is used, False otherwise. 
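+
+    Example (editorial sketch of the key shape check; the `match_bindings` entries
+    are ir.Value objects, read the same way as in the implementation below):
+
+        weight_for_norm = match_bindings["weight_for_norm"]
+        adjusted_input_shape = match_bindings["adjusted_input_shape"]
+        g = weight_for_norm.shape[0]
+        # Condition 5: the first Reshape must flatten the input to [0, g, -1].
+        matches = adjusted_input_shape.value_as_np_array.tolist() == [0, g, -1]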
+ """ + return _check_if_simulated_instance_norm_is_used_impl(**match_bindings) + + +def instance_simulates_group_normalization_pattern( + input_x, + adjusted_input_shape, + original_input_shape, + weight_for_norm, + bias_for_norm, + weight_full, + bias_full, + epsilon, + match_bindings: dict[str, ir.Value | Any] | None = None, # noqa: ARG001 +): + adjusted_input = op.Reshape(input_x, adjusted_input_shape) + inst_norm = op.InstanceNormalization( + adjusted_input, weight_for_norm, bias_for_norm, epsilon=epsilon + ) + adjusted_inst_norm = op.Reshape(inst_norm, original_input_shape) + mul = op.Mul(adjusted_inst_norm, weight_full) + return op.Add(mul, bias_full) + + +def group_normalization( + input_x, + adjusted_input_shape, # noqa: ARG001 + original_input_shape, # noqa: ARG001 + weight_for_norm, # noqa: ARG001 + bias_for_norm, # noqa: ARG001 + weight_full, + bias_full, + epsilon, + match_bindings: dict[str, ir.Value | Any] | None = None, +): + # com.microsoft.GroupNorm only supports NHWC for now + nhwc_input = op.Transpose(input_x, perm=[0, 2, 3, 1]) + # com.microsoft.GroupNorm only supports gamma and beta as float type + weight_full = op.Cast(weight_full, to=onnx.TensorProto.FLOAT) + reshape_to_1d = op.Constant(value_ints=[-1]) + weight_full = op.Reshape(weight_full, reshape_to_1d) + bias_full = op.Cast(bias_full, to=onnx.TensorProto.FLOAT) + bias_full = op.Reshape(bias_full, reshape_to_1d) + # re-obtain attribute groups + groups = match_bindings["weight_for_norm"].shape[0] + output = msft_op.GroupNorm( + nhwc_input, + weight_full, + bias_full, + activation=0, + channels_last=1, + epsilon=epsilon, + groups=groups, + ) + return op.Transpose(output, perm=[0, 3, 1, 2]) + + +# Register the rewrite rules +instance_norm_to_group_norm_rule = pattern.RewriteRule( + instance_simulates_group_normalization_pattern, + pattern.ReplacementPatternFunction(group_normalization, delay_run=True), + check_if_simulated_instance_norm_is_used, +) + +rules = pattern.RewriteRuleSet([instance_norm_to_group_norm_rule]) diff --git a/onnxscript/onnxrewriter/rewriter/onnxruntime/instance_to_group_normalization_test.py b/onnxscript/onnxrewriter/rewriter/onnxruntime/instance_to_group_normalization_test.py new file mode 100644 index 0000000000..1e12dff6ce --- /dev/null +++ b/onnxscript/onnxrewriter/rewriter/onnxruntime/instance_to_group_normalization_test.py @@ -0,0 +1,435 @@ +import unittest + +import numpy as np +import onnx.parser + +from onnxrewriter.ir import irbuilder +from onnxrewriter.rewriter.onnxruntime import instance_to_group_normalization + + +class ReplaceInstanceNormWithGroupNormTest(unittest.TestCase): + def _set_up_model_initializers( + self, + model, + weight_for_norm_value, + weight_for_norm_shape, + bias_for_norm_value, + bias_for_norm_shape, + weight_full_value, + weight_full_shape, + bias_full_value, + bias_full_shape, + ): + """Set up the model initializers for the test.""" + model.graph.initializer.extend( + [ + onnx.helper.make_tensor( + "weight_for_norm", + onnx.TensorProto.FLOAT16, + weight_for_norm_shape, + weight_for_norm_value, + ), + onnx.helper.make_tensor( + "bias_for_norm", + onnx.TensorProto.FLOAT16, + bias_for_norm_shape, + bias_for_norm_value, + ), + onnx.helper.make_tensor( + "weight_full", + onnx.TensorProto.FLOAT16, + weight_full_shape, + weight_full_value, + ), + onnx.helper.make_tensor( + "bias_full", + onnx.TensorProto.FLOAT16, + bias_full_shape, + bias_full_value, + ), + ] + ) + + def test_simulated_instance_norm_is_replaced_by_group_norm(self): + model = onnx.parser.parse_model( 
+ """ + + agraph (float[1, 320, 128, 128] image) => (float[1, 4, 512, 64] output) + { + adjusted_input_shape = Constant() + image_reshape = Reshape (image, adjusted_input_shape) + instance_norm = InstanceNormalization (image_reshape, weight_for_norm, bias_for_norm) + original_input_shape = Constant() + instance_norm_reshape = Reshape (instance_norm, original_input_shape) + mul_output = Mul (instance_norm_reshape, weight_full) + output = Add (mul_output, bias_full) + } + """ + ) + # Use inserted initializers to avoid manually coding the large constants + weight_full_value = np.random.rand(320, 1, 1).astype(np.float16) + bias_full_value = np.random.rand(320, 1, 1).astype(np.float16) + weight_for_norm_value = np.ones(32, dtype=np.float16) + bias_for_norm_value = np.zeros(32, dtype=np.float16) + self._set_up_model_initializers( + model, + weight_for_norm_value, + [32], + bias_for_norm_value, + [32], + weight_full_value, + [320, 1, 1], + bias_full_value, + [320, 1, 1], + ) + + ir = irbuilder.build_ir(model) + count = instance_to_group_normalization.rules.apply_to_model(ir) + self.assertEqual(count, 1) + # plus 2 in model constants + self.assertEqual(len(ir.graph.nodes), 10) + + def test_instance_norm_with_non_one_weight_for_norm_should_remain(self): + model = onnx.parser.parse_model( + """ + + agraph (float[1, 320, 128, 128] image) => (float[1, 4, 512, 64] output) + { + adjusted_input_shape = Constant() + image_reshape = Reshape (image, adjusted_input_shape) + instance_norm = InstanceNormalization (image_reshape, weight_for_norm, bias_for_norm) + original_input_shape = Constant() + instance_norm_reshape = Reshape (instance_norm, original_input_shape) + mul_output = Mul (instance_norm_reshape, weight_full) + output = Add (mul_output, bias_full) + } + """ + ) + # Use inserted initializers to avoid manually coding the large constants + weight_full_value = np.random.rand(320, 1, 1).astype(np.float16) + bias_full_value = np.random.rand(320, 1, 1).astype(np.float16) + weight_for_norm_value = np.random.rand(32).astype(np.float16) + bias_for_norm_value = np.zeros(32, dtype=np.float16) + self._set_up_model_initializers( + model, + weight_for_norm_value, + [32], + bias_for_norm_value, + [32], + weight_full_value, + [320, 1, 1], + bias_full_value, + [320, 1, 1], + ) + + ir = irbuilder.build_ir(model) + count = instance_to_group_normalization.rules.apply_to_model(ir) + self.assertEqual(count, 0) + + def test_instance_norm_with_non_zero_b_should_remain(self): + model = onnx.parser.parse_model( + """ + + agraph (float[1, 320, 128, 128] image) => (float[1, 4, 512, 64] output) + { + adjusted_input_shape = Constant() + image_reshape = Reshape (image, adjusted_input_shape) + instance_norm = InstanceNormalization (image_reshape, weight_for_norm, bias_for_norm) + original_input_shape = Constant() + instance_norm_reshape = Reshape (instance_norm, original_input_shape) + mul_output = Mul (instance_norm_reshape, weight_full) + output = Add (mul_output, bias_full) + } + """ + ) + # Use inserted initializers to avoid manually coding the large constants + weight_full_value = np.random.rand(320, 1, 1).astype(np.float16) + bias_full_value = np.random.rand(320, 1, 1).astype(np.float16) + weight_for_norm_value = np.ones(32, dtype=np.float16) + bias_for_norm_value = np.random.rand(32).astype(np.float16) + self._set_up_model_initializers( + model, + weight_for_norm_value, + [32], + bias_for_norm_value, + [32], + weight_full_value, + [320, 1, 1], + bias_full_value, + [320, 1, 1], + ) + + ir = irbuilder.build_ir(model) + count 
= instance_to_group_normalization.rules.apply_to_model(ir) + self.assertEqual(count, 0) + + def test_instance_norm_with_non_broadcasted_weight_full_should_remain(self): + model = onnx.parser.parse_model( + """ + + agraph (float[1, 320, 128, 128] image) => (float[1, 4, 512, 64] output) + { + adjusted_input_shape = Constant() + image_reshape = Reshape (image, adjusted_input_shape) + instance_norm = InstanceNormalization (image_reshape, weight_for_norm, bias_for_norm) + original_input_shape = Constant() + instance_norm_reshape = Reshape (instance_norm, original_input_shape) + mul_output = Mul (instance_norm_reshape, weight_full) + output = Add (mul_output, bias_full) + } + """ + ) + # Use inserted initializers to avoid manually coding the large constants + weight_full_value = np.random.rand(320).astype(np.float16) + bias_full_value = np.random.rand(320, 1, 1).astype(np.float16) + weight_for_norm_value = np.ones(32, dtype=np.float16) + bias_for_norm_value = np.zeros(32, dtype=np.float16) + self._set_up_model_initializers( + model, + weight_for_norm_value, + [32], + bias_for_norm_value, + [32], + weight_full_value, + [320], + bias_full_value, + [320, 1, 1], + ) + + ir = irbuilder.build_ir(model) + count = instance_to_group_normalization.rules.apply_to_model(ir) + self.assertEqual(count, 0) + + def test_instance_norm_with_non_broadcasted_bias_full_should_remain(self): + model = onnx.parser.parse_model( + """ + + agraph (float[1, 320, 128, 128] image) => (float[1, 4, 512, 64] output) + { + adjusted_input_shape = Constant() + image_reshape = Reshape (image, adjusted_input_shape) + instance_norm = InstanceNormalization (image_reshape, weight_for_norm, bias_for_norm) + original_input_shape = Constant() + instance_norm_reshape = Reshape (instance_norm, original_input_shape) + mul_output = Mul (instance_norm_reshape, weight_full) + output = Add (mul_output, bias_full) + } + """ + ) + # Use inserted initializers to avoid manually coding the large constants + weight_full_value = np.random.rand(320, 1, 1).astype(np.float16) + bias_full_value = np.random.rand(320).astype(np.float16) + weight_for_norm_value = np.ones(32, dtype=np.float16) + bias_for_norm_value = np.zeros(32, dtype=np.float16) + self._set_up_model_initializers( + model, + weight_for_norm_value, + [32], + bias_for_norm_value, + [32], + weight_full_value, + [320, 1, 1], + bias_full_value, + [320], + ) + + ir = irbuilder.build_ir(model) + count = instance_to_group_normalization.rules.apply_to_model(ir) + self.assertEqual(count, 0) + + def test_instance_norm_with_rank_not_4_should_remain(self): + model = onnx.parser.parse_model( + """ + + agraph (float[1, 320, 128] image) => (float[1, 4, 512, 64] output) + { + adjusted_input_shape = Constant() + image_reshape = Reshape (image, adjusted_input_shape) + instance_norm = InstanceNormalization (image_reshape, weight_for_norm, bias_for_norm) + original_input_shape = Constant() + instance_norm_reshape = Reshape (instance_norm, original_input_shape) + mul_output = Mul (instance_norm_reshape, weight_full) + output = Add (mul_output, bias_full) + } + """ + ) + # Use inserted initializers to avoid manually coding the large constants + weight_full_value = np.random.rand(320, 1, 1).astype(np.float16) + bias_full_value = np.random.rand(320, 1, 1).astype(np.float16) + weight_for_norm_value = np.ones(32, dtype=np.float16) + bias_for_norm_value = np.zeros(32, dtype=np.float16) + self._set_up_model_initializers( + model, + weight_for_norm_value, + [32], + bias_for_norm_value, + [32], + weight_full_value, + [320, 
1, 1], + bias_full_value, + [320, 1, 1], + ) + + ir = irbuilder.build_ir(model) + count = instance_to_group_normalization.rules.apply_to_model(ir) + self.assertEqual(count, 0) + + def test_instance_norm_with_weight_full_having_multiple_not_one_dim_should_remain( + self, + ): + model = onnx.parser.parse_model( + """ + + agraph (float[1, 320, 128, 128] image) => (float[1, 4, 512, 64] output) + { + adjusted_input_shape = Constant() + image_reshape = Reshape (image, adjusted_input_shape) + instance_norm = InstanceNormalization (image_reshape, weight_for_norm, bias_for_norm) + original_input_shape = Constant() + instance_norm_reshape = Reshape (instance_norm, original_input_shape) + mul_output = Mul (instance_norm_reshape, weight_full) + output = Add (mul_output, bias_full) + } + """ + ) + # Use inserted initializers to avoid manually coding the large constants + weight_full_value = np.random.rand(320, 2, 3).astype(np.float16) + bias_full_value = np.random.rand(320, 1, 1).astype(np.float16) + weight_for_norm_value = np.ones(32, dtype=np.float16) + bias_for_norm_value = np.zeros(32, dtype=np.float16) + self._set_up_model_initializers( + model, + weight_for_norm_value, + [32], + bias_for_norm_value, + [32], + weight_full_value, + [320, 2, 3], + bias_full_value, + [320, 1, 1], + ) + + ir = irbuilder.build_ir(model) + count = instance_to_group_normalization.rules.apply_to_model(ir) + self.assertEqual(count, 0) + + def test_instance_norm_with_bias_full_having_multiple_not_one_dim_should_remain( + self, + ): + model = onnx.parser.parse_model( + """ + + agraph (float[1, 320, 128, 128] image) => (float[1, 4, 512, 64] output) + { + adjusted_input_shape = Constant() + image_reshape = Reshape (image, adjusted_input_shape) + instance_norm = InstanceNormalization (image_reshape, weight_for_norm, bias_for_norm) + original_input_shape = Constant() + instance_norm_reshape = Reshape (instance_norm, original_input_shape) + mul_output = Mul (instance_norm_reshape, weight_full) + output = Add (mul_output, bias_full) + } + """ + ) + # Use inserted initializers to avoid manually coding the large constants + weight_full_value = np.random.rand(320, 1, 1).astype(np.float16) + bias_full_value = np.random.rand(320, 2, 3).astype(np.float16) + weight_for_norm_value = np.ones(32, dtype=np.float16) + bias_for_norm_value = np.zeros(32, dtype=np.float16) + self._set_up_model_initializers( + model, + weight_for_norm_value, + [32], + bias_for_norm_value, + [32], + weight_full_value, + [320, 1, 1], + bias_full_value, + [320, 2, 3], + ) + + ir = irbuilder.build_ir(model) + count = instance_to_group_normalization.rules.apply_to_model(ir) + self.assertEqual(count, 0) + + def test_instance_norm_with_not_0_g_negative_1_shape_of_adjusted_input_shape_should_remain( + self, + ): + model = onnx.parser.parse_model( + """ + + agraph (float[1, 320, 128, 128] image) => (float[1, 4, 512, 64] output) + { + adjusted_input_shape = Constant() + image_reshape = Reshape (image, adjusted_input_shape) + instance_norm = InstanceNormalization (image_reshape, weight_for_norm, bias_for_norm) + original_input_shape = Constant() + instance_norm_reshape = Reshape (instance_norm, original_input_shape) + mul_output = Mul (instance_norm_reshape, weight_full) + output = Add (mul_output, bias_full) + } + """ + ) + # Use inserted initializers to avoid manually coding the large constants + weight_full_value = np.random.rand(320, 1, 1).astype(np.float16) + bias_full_value = np.random.rand(320, 1, 1).astype(np.float16) + weight_for_norm_value = np.ones(32, 
dtype=np.float16) + bias_for_norm_value = np.zeros(32, dtype=np.float16) + self._set_up_model_initializers( + model, + weight_for_norm_value, + [32], + bias_for_norm_value, + [32], + weight_full_value, + [320, 1, 1], + bias_full_value, + [320, 1, 1], + ) + + ir = irbuilder.build_ir(model) + count = instance_to_group_normalization.rules.apply_to_model(ir) + self.assertEqual(count, 0) + + def test_instance_norm_with_non_equal_of_image_shape_and_original_input_shape_should_remain( + self, + ): + model = onnx.parser.parse_model( + """ + + agraph (float[1, 320, 128, 128] image) => (float[1, 4, 512, 64] output) + { + adjusted_input_shape = Constant() + image_reshape = Reshape (image, adjusted_input_shape) + instance_norm = InstanceNormalization (image_reshape, weight_for_norm, bias_for_norm) + original_input_shape = Constant() + instance_norm_reshape = Reshape (instance_norm, original_input_shape) + mul_output = Mul (instance_norm_reshape, weight_full) + output = Add (mul_output, bias_full) + } + """ + ) + # Use inserted initializers to avoid manually coding the large constants + weight_full_value = np.random.rand(320, 1, 1).astype(np.float16) + bias_full_value = np.random.rand(320, 1, 1).astype(np.float16) + weight_for_norm_value = np.ones(32, dtype=np.float16) + bias_for_norm_value = np.zeros(32, dtype=np.float16) + self._set_up_model_initializers( + model, + weight_for_norm_value, + [32], + bias_for_norm_value, + [32], + weight_full_value, + [320, 1, 1], + bias_full_value, + [320, 1, 1], + ) + + ir = irbuilder.build_ir(model) + count = instance_to_group_normalization.rules.apply_to_model(ir) + self.assertEqual(count, 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/onnxrewriter/rewriter/onnxruntime/softmax.py b/onnxscript/onnxrewriter/rewriter/onnxruntime/softmax.py new file mode 100644 index 0000000000..eca6dde17f --- /dev/null +++ b/onnxscript/onnxrewriter/rewriter/onnxruntime/softmax.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +import logging +from typing import Any + +import onnx + +from onnxrewriter import ir +from onnxrewriter.rewriter import pattern + +op = pattern.onnxop +logger = logging.getLogger(__name__) + + +def softmax_with_fp32_upcast(input, axis): + upcast = op.Cast(input, to=onnx.TensorProto.FLOAT) + softmax = op.Softmax(upcast, axis=axis) + return op.Cast(softmax, to=onnx.TensorProto.FLOAT16) + + +def softmax(input, axis): + return op.Softmax(input, axis=axis) + + +def softmax_with_fp32_upcast_without_axis(input): + upcast = op.Cast(input, to=onnx.TensorProto.FLOAT) + softmax = op.Softmax(upcast) + return op.Cast(softmax, to=onnx.TensorProto.FLOAT16) + + +def softmax_without_axis(input): + return op.Softmax(input) + + +def check_if_fp16_input(match_bindings: dict[str, ir.Value | Any]) -> bool: + input_val = match_bindings.get("input") + if input_val is None: + logger.warning( + "Cannot perform softmax upcast removal: " + "cannot retrieve match_bindings for 'input' for dtype validation." + ) + return False + return input_val.element_type == onnx.TensorProto.FLOAT16 + + +""" +This is an onnxruntime specific pattern. Softmax upcast is a common +pattern observed in transformers models to prevent overflow. However +this is not required since onnxruntime implementation already takes +overflow into account. Hence it is safe to remove the surrounding casts +to free up memory as well as saving performance. 
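+
+For example (illustrative), a float16 subgraph of the form
+
+    x_fp32 = Cast <to = 1> (x)
+    z_fp32 = Softmax (x_fp32)
+    z = Cast <to = 10> (z_fp32)
+
+is rewritten to a single float16 Softmax (with or without an explicit axis), as
+exercised in softmax_test.py below.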
+""" +rules = pattern.RewriteRuleSet( + [ + pattern.RewriteRule(softmax_with_fp32_upcast, softmax, check_if_fp16_input), + pattern.RewriteRule( + softmax_with_fp32_upcast_without_axis, + softmax_without_axis, + check_if_fp16_input, + ), + ] +) diff --git a/onnxscript/onnxrewriter/rewriter/onnxruntime/softmax_test.py b/onnxscript/onnxrewriter/rewriter/onnxruntime/softmax_test.py new file mode 100644 index 0000000000..f32a1231b5 --- /dev/null +++ b/onnxscript/onnxrewriter/rewriter/onnxruntime/softmax_test.py @@ -0,0 +1,92 @@ +import unittest + +import onnx.parser +import parameterized + +from onnxrewriter.ir import irbuilder +from onnxrewriter.rewriter.onnxruntime import softmax + + +class SoftmaxUpcastRemovalTest(unittest.TestCase): + @parameterized.parameterized.expand( + [ + ("Softmax",), + ("Softmax",), + ] + ) + def test_softmax_upcast_to_fp32_is_removed_when_input_and_final_output_is_fp16( + self, softmax_op_str + ): + model = onnx.parser.parse_model( + f""" + + agraph (float16[N] x) => (float16[N] z) + {{ + x_fp32 = Cast(x) + z_fp32 = {softmax_op_str}(x_fp32) + z = Cast(z_fp32) + }} + """ + ) + ir = irbuilder.build_ir(model) + count = softmax.rules.apply_to_model(ir) + self.assertEqual(count, 1) + self.assertNotIn("Cast", {node.op_type for node in ir.graph.nodes}) + + @parameterized.parameterized.expand( + [ + ("Softmax",), + ("Softmax",), + ] + ) + def test_softmax_upcast_to_fp32_is_not_removed_when_input_is_not_fp16( + self, softmax_op_str + ): + model = onnx.parser.parse_model( + f""" + + agraph (int32[N] x) => (float16[N] z) + {{ + x_fp32 = Cast(x) + z_fp32 = {softmax_op_str}(x_fp32) + z = Cast(z_fp32) + }} + """ + ) + ir = irbuilder.build_ir(model) + count = softmax.rules.apply_to_model(ir) + self.assertEqual(count, 0) + self.assertEqual( + len([node.op_type for node in ir.graph.nodes if node.op_type == "Cast"]), 2 + ) + + @parameterized.parameterized.expand( + [ + ("Softmax",), + ("Softmax",), + ] + ) + def test_softmax_upcast_to_fp32_is_not_removed_when_final_output_is_not_fp16( + self, softmax_op_str + ): + model = onnx.parser.parse_model( + f""" + + agraph (float16[N] x) => (double[N] z) + {{ + x_fp32 = Cast(x) + z_fp32 = {softmax_op_str}(x_fp32) + z = Cast(z_fp32) + }} + """ + ) + ir = irbuilder.build_ir(model) + count = softmax.rules.apply_to_model(ir) + self.assertEqual(count, 0) + self.assertEqual( + len([node.op_type for node in ir.graph.nodes if node.op_type == "Cast"]), 2 + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/onnxrewriter/rewriter/onnxruntime/transformers/__init__.py b/onnxscript/onnxrewriter/rewriter/onnxruntime/transformers/__init__.py new file mode 100644 index 0000000000..334e153275 --- /dev/null +++ b/onnxscript/onnxrewriter/rewriter/onnxruntime/transformers/__init__.py @@ -0,0 +1,16 @@ +from __future__ import annotations + +from onnxrewriter.rewriter import function_rule +from onnxrewriter.rewriter.onnxruntime.transformers import ( + fastgelu, + layernorm, + multihead_attention, +) + +TRANSFORMERS_FUNCTION_REWRITE_RULES: list[type[function_rule.FunctionRewriteRule]] = [ + multihead_attention.GQALlama2RewriteRule, + multihead_attention.GQALlamaSdpa2RewriteRule, + multihead_attention.AttnPhi15RewriteRule, + layernorm.LNRewriteRule, + fastgelu.GeluRewriteRule, +] diff --git a/onnxscript/onnxrewriter/rewriter/onnxruntime/transformers/fastgelu.py b/onnxscript/onnxrewriter/rewriter/onnxruntime/transformers/fastgelu.py new file mode 100644 index 0000000000..d9a0b9a9ca --- /dev/null +++ 
b/onnxscript/onnxrewriter/rewriter/onnxruntime/transformers/fastgelu.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +import logging + +import onnx +import onnxscript + +from onnxrewriter.rewriter import function_rule + +logger = logging.getLogger(__name__) + + +class GeluRewriteRule(function_rule.FunctionRewriteRule): + FUNCTION_KEYWORD = "GELUActivation" + PACKAGE_NAME = "transformers" + _version_controller = function_rule.VersionController() + + @_version_controller.register_version() + def _fusion( + self, function: onnx.FunctionProto + ) -> tuple[onnx.FunctionProto, list[onnx.OperatorSetIdProto]]: + del function # Unused + op = self.onnx_opset + msft_opset = onnxscript.values.Opset("com.microsoft", 1) + + def gelu(input): + return msft_opset.FastGelu(input) + + return onnxscript.script(default_opset=op)(gelu).to_function_proto(), ( + onnx.helper.make_operatorsetid("com.microsoft", 1), + ) diff --git a/onnxscript/onnxrewriter/rewriter/onnxruntime/transformers/fastgelu_test.py b/onnxscript/onnxrewriter/rewriter/onnxruntime/transformers/fastgelu_test.py new file mode 100644 index 0000000000..36fca2b9f8 --- /dev/null +++ b/onnxscript/onnxrewriter/rewriter/onnxruntime/transformers/fastgelu_test.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +import unittest + +import numpy as np + +from tests import common + + +class FastGeluParityTest(unittest.TestCase): + def setUp(self): + np.random.seed(0) + + def test_gelu_phi_1_5(self): + common.test_onnxruntime_rewrite( + "gelu_phi_1_5", 4, {("com.microsoft", "FastGelu", "")} + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/onnxrewriter/rewriter/onnxruntime/transformers/layernorm.py b/onnxscript/onnxrewriter/rewriter/onnxruntime/transformers/layernorm.py new file mode 100644 index 0000000000..ab59b70792 --- /dev/null +++ b/onnxscript/onnxrewriter/rewriter/onnxruntime/transformers/layernorm.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +import logging + +import onnx +import onnxscript +from onnx import numpy_helper + +from onnxrewriter.rewriter import function_rule + +logger = logging.getLogger(__name__) + + +class LNRewriteRule(function_rule.FunctionRewriteRule): + FUNCTION_KEYWORD = "layernorm" + PACKAGE_NAME = "transformers" + _version_controller = function_rule.VersionController() + + @_version_controller.register_version() + def _fusion( # type: ignore[misc] + self, function: onnx.FunctionProto + ) -> tuple[onnx.FunctionProto, list[onnx.OperatorSetIdProto]]: + # TODO(bowbao): Might be more desirable to annotate as attribute in nn.Module + aten_add_node = self._find_node_by_type(function, "", "Add") + if aten_add_node is None: + raise function_rule.FunctionRewriteError("Could not find Add node") + + eps_node = self._find_constant_node(function, aten_add_node.input[1]) + if eps_node is None: + raise function_rule.FunctionRewriteError("Could not find eps node") + + eps = numpy_helper.to_array(eps_node.attribute[0].t).item() + logger.info("eps: %s", eps) + + # TODO(ORT): SimplifiedLayerNormalization in ort is defined under onnx domain. 
+ # https://github.com/microsoft/onnxruntime/issues/7573 + # msft_op = onnxscript.values.Opset("com.microsoft", 1) + op = self.onnx_opset + + def ln(input, weight): + return op.SimplifiedLayerNormalization( + input, weight, axis=-1, epsilon=eps, stash_type=1 + ) + + return onnxscript.script(default_opset=op)(ln).to_function_proto(), [] diff --git a/onnxscript/onnxrewriter/rewriter/onnxruntime/transformers/layernorm_test.py b/onnxscript/onnxrewriter/rewriter/onnxruntime/transformers/layernorm_test.py new file mode 100644 index 0000000000..ef883290bf --- /dev/null +++ b/onnxscript/onnxrewriter/rewriter/onnxruntime/transformers/layernorm_test.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +import unittest + +import numpy as np + +from tests import common + + +class LNParityTest(unittest.TestCase): + def setUp(self): + np.random.seed(0) + + def test_ln_llama2(self): + common.test_onnxruntime_rewrite( + "ln_llama2", 4, {("", "SimplifiedLayerNormalization", "")} + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/onnxrewriter/rewriter/onnxruntime/transformers/multihead_attention.py b/onnxscript/onnxrewriter/rewriter/onnxruntime/transformers/multihead_attention.py new file mode 100644 index 0000000000..94246b376f --- /dev/null +++ b/onnxscript/onnxrewriter/rewriter/onnxruntime/transformers/multihead_attention.py @@ -0,0 +1,642 @@ +r"""POC experimenting function aware pattern re-write. + +In this case we don't want to spell-out the entire source pattern. +Instead, we want to replace an entire function call a new subgraph. + +Source function: LlamaAttention +inputs (positional args, the names in function definition are unfortunately arbitrary and don't provide value): + - hidden_states + - position_id + - attention_mask + - q_proj.weight + - k_proj.weight + - v_proj.weight + - cos_cached + - sin_cached + - o_proj.weight +outputs (similarly, positional) + - present_value + - present_key + - attn_output (o_proj) + +The rewriting algorithm is as follows: + +The final new function graph should look like this: + + function_proj_q function_proj_k + | | + | | +com.microsoft::RotaryEmbedding com.microsoft::RotaryEmbedding function_proj_v + \ / / + \ / / + \ / / + \--------------- / -----------------------/ + com.microsoft::MultiHeadAttention + | | | + attn_output (present_key) (present_value) + | + function_proj_o + | + (output) + +So all we need, is to locate 'function_proj_q', 'function_proj_k', 'function_proj_v', 'function_proj_o'. +Construct the 4 nodes with new contrib op nodes, and properly name their inputs/outputs. + +""" + +from __future__ import annotations + +import abc +import dataclasses +import logging + +import onnx +import onnxscript +from onnx import helper as onnx_helper + +from onnxrewriter.rewriter import function_rule + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class AttnSizeConfig: + num_attention_heads: int + num_key_value_heads: int + head_size: int + hidden_size: int + + +class AttentionRewriteRule(function_rule.FunctionRewriteRule, abc.ABC): + def infer_attn_size_config(self, function: onnx.FunctionProto) -> AttnSizeConfig: + if len(function.output) != 3: # noqa: PLR2004 + raise function_rule.FunctionRewriteError( + f"Unexpected number of outputs. Expected 3, got {len(function.output)}." 
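+        # Editorial note: the sizes are inferred purely from the output shapes that
+        # FunctionShapeEnv records for this function: present_value has shape
+        # [batch, num_key_value_heads, seq_len, head_size] and attn_output has shape
+        # [batch, seq_len, hidden_size], so num_attention_heads = hidden_size // head_size.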
+ ) + present_value, _, attn_output = function.output + if ( + present_value_ir := self.lookup(function, present_value) + ) is None or present_value_ir.shape is None: + raise function_rule.FunctionRewriteError( + "Failed to find shape for present_value." + ) + if ( + attn_output_ir := self.lookup(function, attn_output) + ) is None or attn_output_ir.shape is None: + raise function_rule.FunctionRewriteError( + "Failed to find shape for attn_output." + ) + head_size = present_value_ir.shape[3] + num_key_value_heads = present_value_ir.shape[1] + hidden_size = attn_output_ir.shape[2] + num_attention_heads = hidden_size // head_size + return AttnSizeConfig( + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + head_size=head_size, + hidden_size=hidden_size, + ) + + +class MHALlama2RewriteRule(AttentionRewriteRule): + FUNCTION_KEYWORD = "LlamaAttention" + PACKAGE_NAME = "transformers" + _version_controller = function_rule.VersionController() + + def __init__(self) -> None: + super().__init__() + + @_version_controller.register_version(min_version="4.33", max_version="4.36") + def _fusion_with_4d_cache( + self, function: onnx.FunctionProto + ) -> tuple[onnx.FunctionProto, tuple[onnx.OperatorSetIdProto]]: + if len(function.input) != 9: # noqa: PLR2004 + raise function_rule.FunctionRewriteError( + f"Unexpected number of inputs. Expected 9, got {len(function.input)}." + ) + + # Infer size configurations from the function. + attn_size_config = self.infer_attn_size_config(function) + + # Code new pattern with onnxscript. + op = onnxscript.opset18 + msft_op = onnxscript.values.Opset("com.microsoft", 1) + + # Workaround onnxscript error by specifying the output shape here. + cos_sin_gather_size = [attn_size_config.head_size // 2] + expand_shape = [1, attn_size_config.num_attention_heads, 1, 1] + + def mha( + hidden_states, + position_id, + attention_mask, + q_proj_weight, + k_proj_weight, + v_proj_weight, + cos_cached, + sin_cached, + o_proj_weight, + ): + q = op.MatMul(hidden_states, op.Transpose(q_proj_weight, [1, 0])) + k = op.MatMul(hidden_states, op.Transpose(k_proj_weight, [1, 0])) + v = op.MatMul(hidden_states, op.Transpose(v_proj_weight, [1, 0])) + + # TODO(onnxscript) + # ValueError: ERROR: Unsupported expression type . + # at: Function 'mha', line 16 + # cos = op.Slice(op.Squeeze(cos_cached, [0, 1]), [0], [cos_sin_gather_size], [1]) + # NOTE: Depending on transformers version, the shape of cos/sin is different. + # In later version, the shape is [seq_len, head_size], so the Squeeze is not needed. + # In this version, the shape is [1, 1, seq_len, head_size], hence the below Squeeze. + cos = op.Slice( + op.Squeeze(cos_cached, [0, 1]), [0], cos_sin_gather_size, [1] + ) + sin = op.Slice( + op.Squeeze(sin_cached, [0, 1]), [0], cos_sin_gather_size, [1] + ) + + q_rope = msft_op.RotaryEmbedding( + q, position_id, cos, sin, interleaved=False + ) + k_rope = msft_op.RotaryEmbedding( + k, position_id, cos, sin, interleaved=False + ) + + # TODO(onnxscript) + # ValueError: ERROR: Unsupported expression type . 
+ # expanded_mask = op.Expand(attention_mask, [1, self.num_heads, 1, 1]) + expanded_mask = op.Expand(attention_mask, expand_shape) + + mha_output, present_key, present_value = msft_op.MultiHeadAttention( + q_rope, + k_rope, + v, + None, + None, + expanded_mask, + num_heads=attn_size_config.num_attention_heads, + ) + attn_output = op.MatMul(mha_output, op.Transpose(o_proj_weight, [1, 0])) + return present_value, present_key, attn_output + + return onnxscript.script(default_opset=onnxscript.opset18)( + mha + ).to_function_proto(), (onnx.helper.make_operatorsetid("com.microsoft", 1),) + + @_version_controller.register_version(min_version="4.36", max_version="4.38") + def _fusion_with_2d_cache( + self, function: onnx.FunctionProto + ) -> tuple[onnx.FunctionProto, tuple[onnx.OperatorSetIdProto]]: + # Infer size configurations from the function. + attn_size_config = self.infer_attn_size_config(function) + + if len(function.input) != 9: # noqa: PLR2004 + raise function_rule.FunctionRewriteError( + f"Unexpected number of inputs. Expected 9, got {len(function.input)}." + ) + + # Code new pattern with onnxscript. + op = onnxscript.opset18 + msft_op = onnxscript.values.Opset("com.microsoft", 1) + + # Workaround onnxscript error by specifying the output shape here. + cos_sin_gather_size = [attn_size_config.head_size // 2] + expand_shape = [1, attn_size_config.num_attention_heads, 1, 1] + + def mha( + hidden_states, + position_id, + attention_mask, + q_proj_weight, + k_proj_weight, + v_proj_weight, + cos_cached, + sin_cached, + o_proj_weight, + ): + q = op.MatMul(hidden_states, op.Transpose(q_proj_weight, [1, 0])) + k = op.MatMul(hidden_states, op.Transpose(k_proj_weight, [1, 0])) + v = op.MatMul(hidden_states, op.Transpose(v_proj_weight, [1, 0])) + + cos = op.Slice(cos_cached, [0], cos_sin_gather_size, [1]) + sin = op.Slice(sin_cached, [0], cos_sin_gather_size, [1]) + + q_rope = msft_op.RotaryEmbedding( + q, position_id, cos, sin, interleaved=False + ) + k_rope = msft_op.RotaryEmbedding( + k, position_id, cos, sin, interleaved=False + ) + + # TODO(onnxscript) + # ValueError: ERROR: Unsupported expression type . + # expanded_mask = op.Expand(attention_mask, [1, self.num_heads, 1, 1]) + expanded_mask = op.Expand(attention_mask, expand_shape) + + mha_output, present_key, present_value = msft_op.MultiHeadAttention( + q_rope, + k_rope, + v, + None, + None, + expanded_mask, + num_heads=attn_size_config.num_attention_heads, + ) + attn_output = op.MatMul(mha_output, op.Transpose(o_proj_weight, [1, 0])) + return present_value, present_key, attn_output + + return onnxscript.script(default_opset=onnxscript.opset18)( + mha + ).to_function_proto(), (onnx.helper.make_operatorsetid("com.microsoft", 1),) + + +class GQALlama2RewriteRule(AttentionRewriteRule): + FUNCTION_KEYWORD = "LlamaAttention" + PACKAGE_NAME = "transformers" + _version_controller = function_rule.VersionController() + + def __init__(self) -> None: + super().__init__() + + @_version_controller.register_version(min_version="4.33", max_version="4.36") + def _fusion_with_4d_cache( + self, function: onnx.FunctionProto + ) -> tuple[onnx.FunctionProto, tuple[onnx.OperatorSetIdProto]]: + if len(function.input) != 9: # noqa: PLR2004 + raise function_rule.FunctionRewriteError( + f"Unexpected number of inputs. Expected 9, got {len(function.input)}." + ) + + # Infer size configurations from the function. + attn_size_config = self.infer_attn_size_config(function) + + # Code new pattern with onnxscript. 
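+        # Editorial note: unlike the MultiHeadAttention rewrite in MHALlama2RewriteRule
+        # above, GroupQueryAttention also needs past/total sequence-length tensors and
+        # a kv_num_heads attribute. The gqa() script below derives both sequence
+        # lengths from the shape of hidden_states, takes kv_num_heads from the
+        # inferred AttnSizeConfig, and leaves attention_mask unused.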
+ op = onnxscript.opset18 + msft_op = onnxscript.values.Opset("com.microsoft", 1) + + # Workaround onnxscript error by specifying the output shape here. + cos_sin_gather_size = [attn_size_config.head_size // 2] + + def gqa( + hidden_states, + position_id, + attention_mask, # noqa: ARG001 + q_proj_weight, + k_proj_weight, + v_proj_weight, + cos_cached, + sin_cached, + o_proj_weight, + ): + q = op.MatMul(hidden_states, op.Transpose(q_proj_weight, [1, 0])) + k = op.MatMul(hidden_states, op.Transpose(k_proj_weight, [1, 0])) + v = op.MatMul(hidden_states, op.Transpose(v_proj_weight, [1, 0])) + + # NOTE: Depending on transformers version, the shape of cos/sin is different. + # In later version, the shape is [seq_len, head_size], so the Squeeze is not needed. + # In this version, the shape is [1, 1, seq_len, head_size], hence the below Squeeze. + cos = op.Slice( + op.Squeeze(cos_cached, [0, 1]), [0], cos_sin_gather_size, [1] + ) + sin = op.Slice( + op.Squeeze(sin_cached, [0, 1]), [0], cos_sin_gather_size, [1] + ) + + q_rope = msft_op.RotaryEmbedding( + q, position_id, cos, sin, interleaved=False + ) + k_rope = msft_op.RotaryEmbedding( + k, position_id, cos, sin, interleaved=False + ) + + batch_size = op.Slice(op.Shape(hidden_states), [0], [1], [0]) + sequence_length = op.Slice(op.Shape(hidden_states), [1], [2], [0]) + past_seq_lengths = op.ConstantOfShape( + batch_size, + value=onnx_helper.make_tensor( + "past_seq_lengths", onnx.TensorProto.INT32, [1], [0] + ), + ) + total_seq_lengths = op.Cast(sequence_length, to=onnx.TensorProto.INT32) + + gqa_output, present_key, present_value = msft_op.GroupQueryAttention( + q_rope, + k_rope, + v, + None, + None, + past_seq_lengths, + total_seq_lengths, + kv_num_heads=attn_size_config.num_key_value_heads, + num_heads=attn_size_config.num_attention_heads, + ) + attn_output = op.MatMul(gqa_output, op.Transpose(o_proj_weight, [1, 0])) + return present_value, present_key, attn_output + + return onnxscript.script(default_opset=onnxscript.opset18)( + gqa + ).to_function_proto(), (onnx.helper.make_operatorsetid("com.microsoft", 1),) + + @_version_controller.register_version(min_version="4.36", max_version="4.38") + def _fusion_with_2d_cache( + self, function: onnx.FunctionProto + ) -> tuple[onnx.FunctionProto, tuple[onnx.OperatorSetIdProto]]: + # Infer size configurations from the function. + attn_size_config = self.infer_attn_size_config(function) + + if len(function.input) != 9: # noqa: PLR2004 + raise function_rule.FunctionRewriteError( + f"Unexpected number of inputs. Expected 9, got {len(function.input)}." + ) + + # Code new pattern with onnxscript. + op = onnxscript.opset18 + msft_op = onnxscript.values.Opset("com.microsoft", 1) + + # Workaround onnxscript error by specifying the output shape here. 
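+        # Editorial note: in this 2D-cache variant (transformers >= 4.36) cos_cached /
+        # sin_cached already have shape [seq_len, head_size], so gqa() below slices
+        # them directly, without the Squeeze used in the 4D-cache variant above.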
+ cos_sin_gather_size = [attn_size_config.head_size // 2] + + def gqa( + hidden_states, + position_id, + attention_mask, # noqa: ARG001 + q_proj_weight, + k_proj_weight, + v_proj_weight, + cos_cached, + sin_cached, + o_proj_weight, + ): + q = op.MatMul(hidden_states, op.Transpose(q_proj_weight, [1, 0])) + k = op.MatMul(hidden_states, op.Transpose(k_proj_weight, [1, 0])) + v = op.MatMul(hidden_states, op.Transpose(v_proj_weight, [1, 0])) + + cos = op.Slice(cos_cached, [0], cos_sin_gather_size, [1]) + sin = op.Slice(sin_cached, [0], cos_sin_gather_size, [1]) + + q_rope = msft_op.RotaryEmbedding( + q, position_id, cos, sin, interleaved=False + ) + k_rope = msft_op.RotaryEmbedding( + k, position_id, cos, sin, interleaved=False + ) + + batch_size = op.Slice(op.Shape(hidden_states), [0], [1], [0]) + sequence_length = op.Slice(op.Shape(hidden_states), [1], [2], [0]) + past_seq_lengths = op.ConstantOfShape( + batch_size, + value=onnx_helper.make_tensor( + "past_seq_lengths", onnx.TensorProto.INT32, [1], [0] + ), + ) + total_seq_lengths = op.Cast(sequence_length, to=onnx.TensorProto.INT32) + + gqa_output, present_key, present_value = msft_op.GroupQueryAttention( + q_rope, + k_rope, + v, + None, + None, + past_seq_lengths, + total_seq_lengths, + kv_num_heads=attn_size_config.num_key_value_heads, + num_heads=attn_size_config.num_attention_heads, + ) + attn_output = op.MatMul(gqa_output, op.Transpose(o_proj_weight, [1, 0])) + return present_value, present_key, attn_output + + return onnxscript.script(default_opset=onnxscript.opset18)( + gqa + ).to_function_proto(), (onnx.helper.make_operatorsetid("com.microsoft", 1),) + + +class GQALlamaSdpa2RewriteRule(AttentionRewriteRule): + # TODO: There are a lot of duplicated code with `MHALlama2RewriteRule`. + # The pitfall is that the source function signature is slightly different. + # One has `attention_mask` as input while the other does not. + # Possibly designing a function template system could help reduce the boilerplate. + FUNCTION_KEYWORD = "LlamaSdpaAttention" + PACKAGE_NAME = "transformers" + _version_controller = function_rule.VersionController() + + def __init__(self) -> None: + super().__init__() + + @_version_controller.register_version(min_version="4.36", max_version="4.38") + def _fusion( + self, function: onnx.FunctionProto + ) -> tuple[onnx.FunctionProto, tuple[onnx.OperatorSetIdProto]]: + # Infer size configurations from the function. + attn_size_config = self.infer_attn_size_config(function) + + # Code new pattern with onnxscript. 
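+ # The LlamaSdpaAttention source function has no attention_mask input (see the class
+ # TODO above), so the gqa() signature below omits it; the rest of the replacement is
+ # the same GroupQueryAttention construction used for LlamaAttention.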
+ op = onnxscript.opset18 + msft_op = onnxscript.values.Opset("com.microsoft", 1) + + cos_sin_gather_size = [attn_size_config.head_size // 2] + + def gqa( + hidden_states, + position_id, + q_proj_weight, + k_proj_weight, + v_proj_weight, + cos_cached, + sin_cached, + o_proj_weight, + ): + q = op.MatMul(hidden_states, op.Transpose(q_proj_weight, [1, 0])) + k = op.MatMul(hidden_states, op.Transpose(k_proj_weight, [1, 0])) + v = op.MatMul(hidden_states, op.Transpose(v_proj_weight, [1, 0])) + + cos = op.Slice(cos_cached, [0], cos_sin_gather_size, [1]) + sin = op.Slice(sin_cached, [0], cos_sin_gather_size, [1]) + + q_rope = msft_op.RotaryEmbedding( + q, position_id, cos, sin, interleaved=False + ) + k_rope = msft_op.RotaryEmbedding( + k, position_id, cos, sin, interleaved=False + ) + + batch_size = op.Slice(op.Shape(hidden_states), [0], [1], [0]) + sequence_length = op.Slice(op.Shape(hidden_states), [1], [2], [0]) + past_seq_lengths = op.ConstantOfShape( + batch_size, + value=onnx_helper.make_tensor( + "past_seq_lengths", onnx.TensorProto.INT32, [1], [0] + ), + ) + total_seq_lengths = op.Cast(sequence_length, to=onnx.TensorProto.INT32) + + gqa_output, present_key, present_value = msft_op.GroupQueryAttention( + q_rope, + k_rope, + v, + None, + None, + past_seq_lengths, + total_seq_lengths, + kv_num_heads=attn_size_config.num_key_value_heads, + num_heads=attn_size_config.num_attention_heads, + ) + attn_output = op.MatMul(gqa_output, op.Transpose(o_proj_weight, [1, 0])) + return present_value, present_key, attn_output + + return onnxscript.script(default_opset=onnxscript.opset18)( + gqa, + ).to_function_proto(), (onnx.helper.make_operatorsetid("com.microsoft", 1),) + + @_version_controller.register_version(min_version="4.38") + def _fusion_without_cos_sin_cache( + self, function: onnx.FunctionProto + ) -> tuple[onnx.FunctionProto, tuple[onnx.OperatorSetIdProto]]: + # Infer size configurations from the function. + attn_size_config = self.infer_attn_size_config(function) + + # Code new pattern with onnxscript. + op = onnxscript.opset18 + msft_op = onnxscript.values.Opset("com.microsoft", 1) + + cos_sin_gather_size = [attn_size_config.head_size // 2] + + def gqa( + hidden_states, + position_id, + causal_mask, # noqa: ARG001 + cache_position, # noqa: ARG001 + q_proj_weight, + k_proj_weight, + v_proj_weight, + inv_freq, + o_proj_weight, + ): + q = op.MatMul(hidden_states, op.Transpose(q_proj_weight, [1, 0])) + k = op.MatMul(hidden_states, op.Transpose(k_proj_weight, [1, 0])) + v = op.MatMul(hidden_states, op.Transpose(v_proj_weight, [1, 0])) + + # In 4.38 and later, cos/sin are not cached, but computed on the fly. + # This can be further optimized by constant folding for scenarios where + # the position_id is known at compile time. 
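+ # The block below recomputes the rotary tables from inv_freq instead of reading a
+ # cache: t = [0, 1, ..., seq_len - 1] as a float column vector, freqs = t @ inv_freq
+ # (an outer product), emb = concat(freqs, freqs) on the last axis, and cos/sin are
+ # Cos(emb)/Sin(emb) cast to the dtype of hidden_states and sliced to
+ # cos_sin_gather_size columns for RotaryEmbedding.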
+ seq_len = op.Slice(op.Shape(hidden_states), [1], [2], [0]) + seq_len_scalar = op.Squeeze(seq_len, [0]) + t = op.Unsqueeze( + op.Cast(op.Range(0, seq_len_scalar, 1), to=onnx.TensorProto.FLOAT), [1] + ) + inv_freq = op.Cast(op.Unsqueeze(inv_freq, [0]), to=onnx.TensorProto.FLOAT) + freqs = op.MatMul(t, inv_freq) + + emb = op.Concat(freqs, freqs, axis=-1) + cos = op.CastLike(op.Cos(emb), hidden_states) + sin = op.CastLike(op.Sin(emb), hidden_states) + cos = op.Slice(cos, [0], cos_sin_gather_size, [1]) + sin = op.Slice(sin, [0], cos_sin_gather_size, [1]) + + q_rope = msft_op.RotaryEmbedding( + q, position_id, cos, sin, interleaved=False + ) + k_rope = msft_op.RotaryEmbedding( + k, position_id, cos, sin, interleaved=False + ) + + batch_size = op.Slice(op.Shape(hidden_states), [0], [1], [0]) + sequence_length = op.Slice(op.Shape(hidden_states), [1], [2], [0]) + past_seq_lengths = op.ConstantOfShape( + batch_size, + value=onnx_helper.make_tensor( + "past_seq_lengths", onnx.TensorProto.INT32, [1], [0] + ), + ) + total_seq_lengths = op.Cast(sequence_length, to=onnx.TensorProto.INT32) + + gqa_output, present_key, present_value = msft_op.GroupQueryAttention( + q_rope, + k_rope, + v, + None, + None, + past_seq_lengths, + total_seq_lengths, + kv_num_heads=attn_size_config.num_key_value_heads, + num_heads=attn_size_config.num_attention_heads, + ) + attn_output = op.MatMul(gqa_output, op.Transpose(o_proj_weight, [1, 0])) + return present_value, present_key, attn_output + + return onnxscript.script(default_opset=onnxscript.opset18)( + gqa, + ).to_function_proto(), (onnx.helper.make_operatorsetid("com.microsoft", 1),) + + +class AttnPhi15RewriteRule(AttentionRewriteRule): + FUNCTION_KEYWORD = "PhiAttention" + PACKAGE_NAME = "transformers_modules" + _version_controller = function_rule.VersionController() + + def __init__(self) -> None: + super().__init__() + + @_version_controller.register_version() + def _fusion( + self, function: onnx.FunctionProto + ) -> tuple[onnx.FunctionProto, tuple[onnx.OperatorSetIdProto]]: + # Infer size configurations from the function. + attn_size_config = self.infer_attn_size_config(function) + + # Code new pattern with onnxscript. + op = onnxscript.opset18 + msft_opset = onnxscript.values.Opset("com.microsoft", 1) + + def phi_attention( + hidden_states, + position_id, # noqa: ARG001 + attention_mask, + q_proj_weight, + q_proj_bias, + k_proj_weight, + k_proj_bias, + v_proj_weight, + v_proj_bias, + cos_cached, # noqa: ARG001 + sin_cached, # noqa: ARG001 + dense_weight, + dense_bias, + ): + qkv_weight = op.Transpose( + op.Concat(q_proj_weight, k_proj_weight, v_proj_weight, axis=0), + perm=[1, 0], + ) + qkv_bias = op.Concat(q_proj_bias, k_proj_bias, v_proj_bias, axis=0) + + # [batch_size, sequence_length] + attention_mask_shape = op.Slice(op.Shape(hidden_states), [0], [2], [0]) + + # Create 2d mask to mimic 4d causal mask. 
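+ # An all-ones [batch_size, sequence_length] mask stands in for the original 4d
+ # additive causal mask: with unidirectional=1 the contrib Attention op applies
+ # causal masking itself, and the all-ones values mean no positions are treated
+ # as padding.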
+ attention_mask = op.ConstantOfShape( + attention_mask_shape, + value=onnx_helper.make_tensor( + "mask_value", onnx.TensorProto.INT32, [1], [1] + ), + ) + attn_output, present = msft_opset.Attention( + hidden_states, + qkv_weight, + qkv_bias, + attention_mask, + unidirectional=1, + do_rotary=1, + # Attention.rotary_embedding_dim only supports 32, 64 or 128 + rotary_embedding_dim=attn_size_config.head_size // 2 // 32 * 32, + num_heads=attn_size_config.num_attention_heads, + ) + present_key = op.Gather(present, 0) + present_value = op.Gather(present, 1) + output = op.Add( + op.MatMul(attn_output, op.Transpose(dense_weight, [1, 0])), dense_bias + ) + + return present_value, present_key, output + + return onnxscript.script(default_opset=onnxscript.opset18)( + phi_attention + ).to_function_proto(), (onnx.helper.make_operatorsetid("com.microsoft", 1),) diff --git a/onnxscript/onnxrewriter/rewriter/onnxruntime/transformers/multihead_attention_test.py b/onnxscript/onnxrewriter/rewriter/onnxruntime/transformers/multihead_attention_test.py new file mode 100644 index 0000000000..e4fe07423c --- /dev/null +++ b/onnxscript/onnxrewriter/rewriter/onnxruntime/transformers/multihead_attention_test.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +import unittest + +import numpy as np + +from tests import common + + +class MHAParityTest(unittest.TestCase): + def setUp(self): + np.random.seed(0) + + @common.skip_if_no_cuda("GQA Kernel unsupported on CPU.") + def test_attn_llama2_4_34(self): + common.test_onnxruntime_rewrite( + "attn_llama2_4_34", 2, {("com.microsoft", "GroupQueryAttention", "")} + ) + + @common.skip_if_no_cuda("GQA Kernel unsupported on CPU.") + def test_attn_llama2_4_36(self): + common.test_onnxruntime_rewrite( + "attn_llama2_4_36", 1, {("com.microsoft", "GroupQueryAttention", "")} + ) + + @common.skip_if_no_cuda("GQA Kernel unsupported on CPU.") + def test_attn_yi_4_37(self): + common.test_onnxruntime_rewrite( + "attn_yi_4_37", 1, {("com.microsoft", "GroupQueryAttention", "")} + ) + + @common.skip_if_no_cuda("GQA Kernel unsupported on CPU.") + def test_sdpa_llama2_4_36(self): + # TODO: Clean-up naming logic of test models. + # Package version was not considered. 
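+ # Each call below names a saved test model and the set of (domain, op_type, overload)
+ # triples expected after rewriting; the exact check is implemented in tests.common
+ # (not part of this diff), presumably rewriting the model and comparing outputs
+ # against ONNX Runtime.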
+ common.test_onnxruntime_rewrite( + "sdpa_llama2", 4, {("com.microsoft", "GroupQueryAttention", "")} + ) + + @unittest.skip("TODO: Fails parity check") + def test_sdpa_llama2_4_38(self): + common.test_onnxruntime_rewrite( + "sdpa_llama2_4_38", 1, {("com.microsoft", "GroupQueryAttention", "")} + ) + + @common.skip_if_no_cuda("GQA Kernel unsupported on CPU.") + def test_sdpa_yi_4_36(self): + common.test_onnxruntime_rewrite( + "sdpa_yi", 2, {("com.microsoft", "GroupQueryAttention", "")} + ) + + @unittest.skip("TODO: Fails parity check") + def test_sdpa_yi_4_38(self): + common.test_onnxruntime_rewrite( + "sdpa_yi_4_38", 1, {("com.microsoft", "GroupQueryAttention", "")} + ) + + +class AttnParityTest(unittest.TestCase): + def setUp(self): + np.random.seed(0) + + @common.skip_if_no_cuda("CPU has parity issue.") + def test_attn_phi_1_5(self): + common.test_onnxruntime_rewrite( + "attn_phi_1_5", 4, {("com.microsoft", "Attention", "")} + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/onnxrewriter/rewriter/pattern.py b/onnxscript/onnxrewriter/rewriter/pattern.py new file mode 100644 index 0000000000..f93e2e4172 --- /dev/null +++ b/onnxscript/onnxrewriter/rewriter/pattern.py @@ -0,0 +1,957 @@ +from __future__ import annotations + +import inspect +import itertools +import math +from typing import Any, Callable, Sequence + +import numpy as np +import onnx +import onnx.numpy_helper +import onnx.printer + +from onnxrewriter import ir +from onnxrewriter.ir import irbuilder + +# Overview of the pattern module: The classes below are used to define both +# patterns (that we search for) and replacements for rewrite rules. +# The matches() method of a pattern is used to check if an IR component +# matches the pattern. +# The to_ir() method of a pattern is used to create a new IR component +# TODO: Ensure that all matches() methods have same type signature (where +# appropriate) and that all to_ir() methods have same type signature (where +# appropriate). 
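+ # A minimal usage sketch (mirroring ReciprocalMulTest in pattern_test.py later in
+ # this diff): a pattern function and a replacement function are written over pattern
+ # variables, wrapped in a RewriteRule, and applied to an ir.Model built by irbuilder:
+ #
+ #     def reciprocal_mul_pattern(x, y):
+ #         return (1 / x) * y        # ValuePattern overloads build Div/Mul patterns
+ #
+ #     def div(x, y):
+ #         return y / x
+ #
+ #     rule = RewriteRule(reciprocal_mul_pattern, div)
+ #     count = rule.apply_to_model(ir_model)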
+ + +class ConstantPattern: + def __init__(self, value: int | str | list) -> None: + self._value = value + + @property + def value(self) -> int | str | list: + return self._value + + def matches(self, value: int | str | list) -> bool: + return value == self.value + + def to_ir(self, model, bindings=None) -> int | str | list: # noqa: ARG002 + return self.value + + +class FloatConstantPattern: + def __init__( + self, value: float, rel_tol: float = 1e-5, abs_tol: float = 1e-8 + ) -> None: + self._value = value + self._rel_tol = rel_tol + self._abs_tol = abs_tol + + @property + def value(self): + return self._value + + def matches(self, value: float): + return math.isclose( + value, self.value, rel_tol=self._rel_tol, abs_tol=self._abs_tol + ) + + def to_ir(self, model, bindings=None) -> float: # noqa: ARG002 + return self.value + + +class TensorConstantPattern: + def __init__( + self, value: np.ndarray, rel_tol: float = 1e-3, abs_tol: float = 1e-3 + ) -> None: + self._value = value + self._rel_tol = rel_tol + self._abs_tol = abs_tol + + @property + def value(self): + return self._value + + def matches(self, value: np.ndarray): + return ( + value.dtype == self._value.dtype + and value.shape == self._value.shape + and np.allclose( + value, + self._value, + rtol=self._rel_tol, + atol=self._abs_tol, + ) + ) + + def to_ir(self, model, bindings=None) -> onnx.TensorProto: # noqa: ARG002 + return onnx.helper.make_tensor( + "", + onnx.helper.np_dtype_to_tensor_dtype(self.value.dtype), + self.value.shape, + self.value, + ) + + +def _make_constant_pattern( + value: float | int | list | np.ndarray, +) -> ConstantPattern | FloatConstantPattern | TensorConstantPattern: + """Convert an attrbute value to a ConstantPattern.""" + if isinstance(value, float): + return FloatConstantPattern(value) + if isinstance(value, (int, list)): + return ConstantPattern(value) + if isinstance(value, np.ndarray): + return TensorConstantPattern(value) + raise TypeError(f"Cannot convert {type(value)} to ConstantPattern") + + +class AnyPattern: + def matches(self, value) -> bool: # noqa: ARG002 + return True + + +class AttrPattern: + def __init__(self, value: Var | int | float | list | np.ndarray) -> None: + if isinstance(value, Var): + self.value_pattern = value + elif isinstance(value, (int, float, list, np.ndarray)): + self.value_pattern = _make_constant_pattern(value) + else: + raise TypeError(f"Cannot convert {type(value)} to AttrPattern") + + def matches(self, attr_val: int | float | list, model: ir.Model) -> MatchResult: + if isinstance(self.value_pattern, Var): + return self.value_pattern.matches(attr_val, model) + return self.value_pattern.matches(attr_val) + + def to_ir( + self, model: ir.Model, rewrite_cache: RewriteCache, bindings=None + ) -> ir.Val: + if isinstance(self.value_pattern, Var): + val, nodes = self.value_pattern.to_ir( + model, bindings, 1, rewrite_cache + ) # TODO: handle multiple outputs + return val + # constant pattern + return self.value_pattern.to_ir(model, bindings) + + +class OpsetPattern: + """Represents an opset pattern. + + It is used primarily to create a NodePattern (via OpPattern). + Example usage: + :: + + z = op.Matmul(x, y) + + Here, `op` is an instance of OpsetPattern and `op.Matmul` is an instance + of OpPattern, and `op.Matmul(x, y)` is an instance of NodePattern. + + An opset pattern is also matched against the actual opset used in the + input model. Typically, we match against an ONNX opset (ignoring the + version), but we can match against a specific version of the opset too. 
+ However, it is preferable that version-dependences are handled at the + level of a rewrite rule, rather than at the level of a pattern. + + """ + + def __init__( + self, + domain_pattern: ConstantPattern, + version_pattern: ConstantPattern | AnyPattern, + ) -> None: + self.domain_pattern = domain_pattern + self.version_pattern = version_pattern + + @classmethod + def singleton(cls, domain: str, version: int): + return cls(ConstantPattern(domain), ConstantPattern(version)) + + @classmethod + def domain(cls, domain: str) -> OpsetPattern: + return cls(ConstantPattern(domain), AnyPattern()) + + def matches(self, opset): + domain, version = opset + return self.domain_pattern.matches(domain) and self.version_pattern.matches( + version + ) + + def to_ir(self, model, bindings=None) -> str: + domain = self.domain_pattern.to_ir(model, bindings) + # TODO: Should we ban other custom domains? + if domain not in model.version_map: + model.version_map[self.domain_pattern.value] = self.version_pattern.value + return domain + + def __getattr__(self, name: str) -> Any: + return OpPattern(self, ConstantPattern(name)) + + +opset17 = OpsetPattern.singleton("", 17) + +onnxop = OpsetPattern.domain("") + +msft_op = OpsetPattern.singleton("com.microsoft", 1) + + +class OpPattern: + """A utility class to build a NodePattern. + + It is used primarily to create a NodePattern. + Example usage: + :: + + z = op.Matmul(x, y) + + Here, `op` is an instance of OpsetPattern and `op.Matmul` is an instance + of OpPattern, and `op.Matmul(x, y)` is an instance of NodePattern. + + """ + + def __init__( + self, opset_pattern: OpsetPattern, op_name_pattern: ConstantPattern + ) -> None: + self.opset_pattern = opset_pattern + self.op_name_pattern = op_name_pattern + + def __call__(self, *args, **kwargs): + if "_num_outputs" in kwargs: + num_outputs = kwargs["_num_outputs"] + del kwargs["_num_outputs"] + else: + num_outputs = 1 + attributes = {name: AttrPattern(value) for (name, value) in kwargs.items()} + node_pattern = NodePattern( + self.opset_pattern, self.op_name_pattern, args, attributes + ) + if num_outputs == 1: + return NodeOutputPattern(node_pattern, 0) + else: + return [NodeOutputPattern(node_pattern, i) for i in range(num_outputs)] + + +def _to_value_pattern(x: ValuePattern | int | float) -> ValuePattern: + """Promotes an input-value used to construct a NodePattern to a ValuePattern. + + Example usage: + :: + x = op.MatMul(a, b) + z = op.Add(x, 0) + + In this example, `a, `b`, and `x` are ValuePatterns used to construct a NodePattern. + `0` is a constant (int) value, and is automatically promoted to a ValuePattern. + + Note that this is a shorthand for creating a Constant pattern. The user can more + explicitly write this as: + :: + z = op.Add(x, op.Constant(0)) + """ + if isinstance(x, ValuePattern): + return x + if isinstance(x, (int, float, list)): + return Constant(x) + # TODO(titaiwang): Could this be wrapped Constant? + raise TypeError(f"Cannot convert {type(x)} to ValuePattern") + + +class MatchResult: + """Represents the result of a match operation. + + A match can either succeed or fail. + If it succeeds, it returns a list of IR values that matched the pattern + and a set of bindings for the variables in the pattern. + + Example: + :: + def pattern(x, shape1, shape2): + t1 = op.Reshape(x, shape1) + t2 = op.Reshape(t1, shape2) + return t2 + The above pattern matches a sequence of two Reshape ops. 
+ The matched_values will contain the values representing the (output of) + the two Reshape ops, and the bindings will contain the values that + are bound to the variables `x`, `shape1`, and `shape2`. + """ + + def __init__( + self, matched_values=None, bindings: dict[str, ir.Value | Any] | None = None + ) -> None: + assert matched_values is None or isinstance(matched_values, list) + self.success: bool = matched_values is not None + # For a successful match, matched_values is a list of values that matched the pattern. + # These include the internal nodes of the pattern that were matched, but not + # the leaves (sub-trees) that match against the variables in the pattern. + # These represent the values that will be replaced by the replacement pattern. + self.matched_values: Sequence[Any] | None = matched_values + # For a successful match, bindings is a dictionary of mapping pattern-variable-names + # to values. + self.bindings: dict[str, Any] = bindings if bindings is not None else {} + + def __bool__(self): + return self.success + + @classmethod + def FAIL(cls): # noqa: N802 + return cls(None) + + @property + def values(self) -> Sequence[Any] | None: + return self.matched_values + + def fail(self): + self.success = False + self.matched_values = None + self.bindings = {} + + def extend(self, other: MatchResult | bool, model): + del model # Unused + if not self.success: + return + if not other: + self.fail() + return + if isinstance(other, bool): + return + for var, val in other.bindings.items(): + if var in self.bindings: + # TODO: handle attribute var bindings + if not self.bindings[var].is_same_as(val): + self.fail() + return + else: + self.bindings[var] = val + self.matched_values.extend(other.matched_values) + + +class ValuePattern: + """Base class for all patterns that match against IR values. + + This is used primarily to provide operator overloadings for arithmetic + operations, so that we can write patterns like `x + 1` and `1 + x`. + """ + + def __init__(self) -> None: + pass + + def __add__(self, other): + return onnxop.Add(self, other) + + def __radd__(self, other): + return onnxop.Add(other, self) + + def __sub__(self, other): + return onnxop.Sub(self, other) + + def __rsub__(self, other): + return onnxop.Sub(other, self) + + def __mul__(self, other): + return onnxop.Mul(self, other) + + def __rmul__(self, other): + return onnxop.Mul(other, self) + + def __truediv__(self, other): + return onnxop.Div(self, other) + + def __rtruediv__(self, other): + return onnxop.Div(other, self) + + def __pow__(self, other): + return onnxop.Pow(self, other) + + +# NOTE(bowbao): Based on reading code, this is (nearly) the only place where `model` is used +# for (nearly) all the functions that passes `model` around. It seems the goal is to be able +# create unique value names. +def _make_node( + model: ir.Model, + domain: str, + op: str, + input, + attributes, + num_outputs: int, +) -> tuple[list[ir.Value], ir.Node]: + inputnames = [x.name for x in input] + outputs = [model.make_new_name() for i in range(num_outputs)] + node = onnx.helper.make_node(op, inputnames, outputs, domain=domain, **attributes) + newnode = ir.Node(node) + newvalues = [ir.Value(v, newnode, i) for i, v in enumerate(outputs)] + newnode.inputs = input + newnode.outputs = newvalues + newnode.attributes = attributes # TODO + return newvalues, newnode + + +class NodePattern: + """Represents a pattern that matches against a Node. 
+ + This differs from a NodeOutputPattern in that it matches against a node (which + may produce 1 or more outputs), whereas a NodeOutputPattern matches against + a specific output of a node. + """ + + def __init__( + self, + domain: OpsetPattern, + op: ConstantPattern, + inputs: Sequence[int | float | ValuePattern], + attributes: dict[str, AttrPattern], + ): + self.domain = domain + self.op = op + self.inputs = [_to_value_pattern(x) for x in inputs] + self.attributes = attributes + self.bound_value = None + + def matches(self, value: ir.Value, model: ir.Model): + if self.bound_value is not None: + # DAG-matching, not Tree-matching. + if self.bound_value.is_same_as(value): + return MatchResult([]) + else: + return MatchResult.FAIL() + node = value.def_node() + if node is None: + # Eg., value could be an input parameter, which will not match a value + # computed by the op in this pattern. + return MatchResult.FAIL() + return self.matches_node(node, model) + + def matches_node(self, node: ir.Node, model: ir.Model) -> MatchResult: + """Examine if the IR node matches the self pattern.""" + if not self.domain.matches((node.domain, node.version)): + return MatchResult.FAIL() + if not self.op.matches(node.op_type): + return MatchResult.FAIL() + match = MatchResult([]) + # TODO: We should add filtered logging starting from here to emit why + # matching failed. This should cut a lot of noises compared to logging everything, + # because at least the starting node op_type is already matched. + for arg_value, previous_node_output_pattern in zip(node.inputs, self.inputs): + # previous_node_output_pattern could be a Var, if it's the original arg. + sub_match = previous_node_output_pattern.matches(arg_value, model) + match.extend(sub_match, model) + if not match: # If sub-match failed, + return match + # Sub-graphs not handled yet. + for name, attr_pattern in self.attributes.items(): + attr_value = node.get_attribute(name) + if attr_value is None: + return MatchResult.FAIL() + sub_match = attr_pattern.matches(attr_value, model) + if not sub_match: + return MatchResult.FAIL() + match.extend(sub_match, model) + for name in node.attributes: + # TODO: Support matching default values for attributes. 
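+ # Any attribute present on the candidate node but not named in the pattern makes
+ # the match fail, keeping matching conservative until defaults are supported.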
+ if name not in self.attributes: + return MatchResult.FAIL() + match.values.append(node) + return match + + def to_ir( + self, + model: ir.Model, + bindings: dict[str, ir.Value | Any], + num_outputs: int, + rewrite_cache: RewriteCache, + ) -> tuple[list[ir.Value], list[ir.Node]]: + domain = self.domain.to_ir(model) + op = self.op.to_ir(model) + inputs = [] + nodes = [] + for val_pattern in self.inputs: + if ( + value_and_node := rewrite_cache.get_node_output_pattern(val_pattern) + ) is not None: + val, n = value_and_node + else: + val, n = val_pattern.to_ir(model, bindings, 1, rewrite_cache) + rewrite_cache.set_node_output_pattern_with_ir(val_pattern, val, n) + nodes.extend(n) + # If one of the inputs was a the output of a previous node, + # unpack the new output ir value that is created for that node + if isinstance(val, list): + # TODO: Move implementation of output_index to NodeOutputPatter.to_ir + inputs.append(val[val_pattern.output_index]) + else: + inputs.append(val) + attributes = { + name: attr_pattern.to_ir(model, rewrite_cache, bindings) + for (name, attr_pattern) in self.attributes.items() + } + newvals, newnode = _make_node( + model, domain, op, inputs, attributes, num_outputs + ) + nodes.append(newnode) + return newvals, nodes + + def commute(self) -> list[ValuePattern]: + list_of_lists = [pattern.commute() for pattern in self.inputs] + + def enumerate_inputs(inputs, index): + if index >= len(inputs): + yield [] + else: + for pattern in inputs[index]: + for rest in enumerate_inputs(inputs, index + 1): + yield [pattern, *rest] + + inputs = list(enumerate_inputs(list_of_lists, 0)) + if self.domain.matches(("", None)) and ( + self.op.matches("Add") or self.op.matches("Mul") + ): + # TODO: handle cases where number of inputs is not 2. + swapped = [[x[1], x[0]] for x in inputs] + inputs.extend(swapped) + return [ + NodePattern(self.domain, self.op, input, self.attributes) + for input in inputs + ] + + +class NodeOutputPattern(ValuePattern): + """Represents a pattern that matches against a specific output of a Node. + + This is the primary pattern used to match against computed values, that + is values computed using a specific op. 
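+ 
+ For example, `op.MatMul(x, y)` (see OpPattern.__call__) evaluates to a
+ NodeOutputPattern referring to output index 0 of the underlying MatMul NodePattern.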
+ """ + + def __init__(self, node_pattern: NodePattern, output_index: int) -> None: + self.node_pattern = node_pattern + self.output_index = output_index + + def matches(self, value: ir.Value, model: ir.Model): + """Match the StaticValueInfo from IR with the `matches_node()` in node pattern.""" + node = value.def_node() + if node is None: + return MatchResult.FAIL() + if value.def_index() != self.output_index: + return MatchResult.FAIL() + return self.node_pattern.matches_node(node, model) + + def to_ir( + self, + model: ir.Model, + bindings: dict[str, ir.Value | Any], + num_outputs: int, + rewrite_cache: RewriteCache, + ) -> tuple[list[ir.Value], list[ir.Node]]: + assert self.output_index == 0, "TODO: handle multiple outputs" + return self.node_pattern.to_ir(model, bindings, num_outputs, rewrite_cache) + + +class Var(ValuePattern): + """Represents a pattern variable.""" + + def __init__(self, name: str) -> None: + self.pattern_var_name = name + self.bound_value = None + + def matches(self, value: ir.Value, model: ir.Model): # noqa: ARG002 + return MatchResult([], {self.pattern_var_name: value}) + + def to_ir( + self, + model: ir.Model, + bindings: dict[str, ir.Value | Any], + num_outputs: int, + rewrite_cache: RewriteCache, + ) -> tuple[ir.Value, list[None]]: + del model # Unused + del num_outputs # Unused + del rewrite_cache # Unused + return bindings[self.pattern_var_name], [] + + def commute(self) -> list[ValuePattern]: + return [self] + + +class Constant(ValuePattern): + """Represents a pattern that matches against a scalar constant value.""" + + def __init__( + self, value: int | float, rel_tol: float = 1e-5, abs_tol: float = 1e-8 + ) -> None: + self.value = value + self.rel_tol = rel_tol + self.abs_tol = abs_tol + + def match_scalar(self, scalar_value, return_value: list[ir.Node]): + if math.isclose( + scalar_value, self.value, rel_tol=self.rel_tol, abs_tol=self.abs_tol + ): + return MatchResult(return_value) + else: + return MatchResult.FAIL() + + def matches(self, value: ir.Value, model: ir.Model): + del model # Unused + constant_value = value.value_as_np_array + if isinstance(constant_value, np.ndarray): + # TODO (rama): allow users to specify shape requirement, if desired. + if constant_value.size != 1: + return MatchResult.FAIL() + + return_value = [] + # Note: If the value is produced by a Constant node, we could include + # the Constant node in the return_value list. However, we don't do that. + # Instead, we will rely on DCE to remove the constant node if it is not + # used elsewhere. + + return self.match_scalar(constant_value.item(), return_value) + return MatchResult.FAIL() + + def commute(self) -> list[ValuePattern]: + return [self] + + +def _handle_pattern_return_value( + node_output_pattern: NodeOutputPattern | list[NodeOutputPattern], +) -> tuple[NodePattern, int]: + """This checks and cleans up the return value of a pattern-construction function. + + A pattern-construction function will return values as below: + :: + def pattern(x, shape1, shape2): + ... + return op.SomeOp(...) + However, `SomeOp` may represent an ONNX op that produces multiple outputs. + This function validates that the return values represent the outputs of + a single NodePattern. It returns the node_pattern and the number of outputs. + + This follows an important restriction of the pattern-matcher algorithm: it + only matches against subgraphs that end in a single terminal node. 
If we + permit two terminal nodes, then we would have to match against all possible + pairs of nodes in the graph, which produces an extra quadratic factor in the + complexity of the pattern-matching algorithm. In general, the complexity becomes + exponential in the number of terminal nodes. + + Args: + node_output_pattern: NodeOutputPattern | list[NodeOutputPattern] + + Returns: + tuple[NodePattern, int]: The last node_pattern, num_outputs + """ + if isinstance(node_output_pattern, NodeOutputPattern): + node_pattern = node_output_pattern.node_pattern + num_outputs = 1 + elif isinstance(node_output_pattern, (list, tuple)): + node_pattern = node_output_pattern[0].node_pattern + num_outputs = len(node_output_pattern) + for i, p in enumerate(node_output_pattern): + assert isinstance(p, NodeOutputPattern) + assert p.node_pattern is node_pattern + assert p.output_index == i + else: + raise TypeError(f"Invalid type {type(node_output_pattern)} for pattern") + return node_pattern, num_outputs + + +# Currently, the replacement graph function is the same as the pattern function. +# This may change in the future. +_handle_replacement_return_value = _handle_pattern_return_value + + +def _valid_to_replace(matched_nodes: Sequence[ir.Node]) -> bool: + """Check that values computed by the matched_nodes, except for the last one, are used only by the matched_nodes.""" + # * Must check that all values matched by pattern are used only by pattern, + # except for the value that is replaced. + # * Must ensure that replacement subgraph does not use any of the deleted + # (intermediate) values. (Not necessary for now. Guaranteed.) + deleted_nodes = matched_nodes[:-1] + for n in deleted_nodes: + for v in n.outputs: + if v.is_output: + # value is an output-value of the graph/function. + return False + for use in v.uses: + if use not in matched_nodes: + return False + return True + + +class TargetPatternFunction: + """The targeted pattern that will be replaced by the replacement pattern. + + Attributes: + function (Callable): The pattern function that will be matched against the IR. + """ + + def __init__(self, function: Callable) -> None: + self._function = function + + @property + def function(self) -> Callable: + return self._function + + def get_pattern(self, *vars: Sequence[Var]) -> tuple[NodePattern, int]: # noqa: A002 + node_output_pattern = self._function(*vars) + return _handle_pattern_return_value(node_output_pattern) + + +class ReplacementPatternFunction: + """The replacement pattern that will replace the targeted pattern. + + Attributes: + function (Callable): The replacement function that will be used to replace the matched pattern. + delay_run (bool): If True, the replacement function will not be run until the matched pattern is found. + This is useful when we want to extract certain metavalue from the matched pattern and use it in the + replacement pattern. + """ + + def __init__(self, function, *, delay_run: bool = False): + self._function = function + self._delay_run = delay_run + + @property + def function(self) -> Callable: + return self._function + + @property + def delay_run(self) -> bool: + return self._delay_run + + # TODO: How do we merge it with to_ir function? 
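+ # When delay_run is True and no bindings are supplied yet, get_pattern() returns
+ # (None, None); RewriteRule.try_rewrite calls it again with the bindings of the
+ # successful match before materializing the replacement IR.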
+ def get_pattern( + self, + *vars: Sequence[Var], # noqa: A002 + match_bindings: dict[str, ir.Value | Any] | None = None, + ) -> tuple[NodePattern | None, int | None]: + if self._delay_run: + if match_bindings is None: + return None, None + node_output_pattern = self._function(*vars, match_bindings) + else: + node_output_pattern = self._function(*vars) + return _handle_pattern_return_value(node_output_pattern) + + +class RewriteCache: + def __init__(self): + self._node_output_pattern_to_ir: dict[ + NodeOutputPattern, tuple[ir.Value, ir.Node] + ] = dict() + + def get_node_output_pattern( + self, node_output_pattern: NodeOutputPattern + ) -> tuple[ir.Value, ir.Node] | None: + return self._node_output_pattern_to_ir.get(node_output_pattern, None) + + def set_node_output_pattern_with_ir( + self, node_output_pattern: NodeOutputPattern, value: ir.Value, node: ir.Node + ) -> bool: + self._node_output_pattern_to_ir[node_output_pattern] = (value, node) + + +class RewriteRule: + def __init__( + self, + target_pattern: TargetPatternFunction | Callable | None = None, + replacement_pattern: ReplacementPatternFunction | Callable | None = None, + condition_function: Callable | None = None, + ) -> None: + """Create a rewrite rule. + + Args: + target_pattern: The pattern function that will be + matched against the IR. + replacement_pattern: The replacement function that + will be used to replace the matched pattern. + condition_function: The condition function that + will be used to check if the pattern matches the IR with ir.Values + constraints in consideration. + + """ + if target_pattern is None: + # NOTE: commute() generated rules will have target_pattern as None + # ReplacementPatternFunction is still needed in try_rewrite + assert replacement_pattern is None + assert condition_function is None + self._replacement_pattern = ReplacementPatternFunction(replacement_pattern) + return + elif replacement_pattern is None: + raise ValueError( + "replacement_pattern must be provided if target_pattern is provided" + ) + # TODO: Do we want to tolerate Callable inputs? + if callable(target_pattern): + target_pattern = TargetPatternFunction(target_pattern) + if callable(replacement_pattern): + replacement_pattern = ReplacementPatternFunction(replacement_pattern) + + self._target_pattern = target_pattern + self._replacement_pattern = replacement_pattern + self._condition_function = condition_function + + _pattern_vars = inspect.signature(self._target_pattern.function).parameters + _replacement_vars = inspect.signature( + self._replacement_pattern.function + ).parameters + # TODO: accept _replacement_vars being subset of _pattern_vars? 
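+ # For now the pattern and replacement functions must declare the same number of
+ # parameters; each name becomes a pattern Var that is bound during matching and
+ # looked up again when the replacement subgraph is built.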
+ assert len(_pattern_vars) == len(_replacement_vars) + + self._vars = [Var(v) for v in _pattern_vars] + # Get the last node pattern and number of outputs from the pattern function + self._target_node_pattern, self._target_num_outputs = ( + self._target_pattern.get_pattern(*self._vars) + ) + # NOTE: Return Nones if the replacement pattern is delayed running + self._replace_node_pattern, _replacement_num_outputs = ( + replacement_pattern.get_pattern(*self._vars) + ) + if _replacement_num_outputs is not None: + assert self._target_num_outputs == _replacement_num_outputs + + def matches(self, node: ir.Node, model: ir.Model) -> MatchResult: + """Check if the node from IR matches the pattern.""" + if len(node.outputs) != self._target_num_outputs: + return MatchResult.FAIL() + match = self._target_node_pattern.matches_node(node, model) + if ( + self._condition_function is not None + and match + and not self._condition_function(match.bindings) + ): + return MatchResult.FAIL() + return match + + def try_rewrite( + self, model: ir.Model, node: ir.Node + ) -> tuple[list[ir.Node], list[ir.Node]] | None: + """If the node matches the pattern, then replace the node with the replacement pattern.""" + match = self.matches(node, model) + if match: + if _valid_to_replace(match.values): + # NOTE: delayed running as the replacement pattern needs bindings + if self._replacement_pattern.delay_run: + # bindings will be consumed by the replacement function + self._replace_node_pattern, _replacement_num_outputs = ( + self._replacement_pattern.get_pattern( + *self._vars[:-1], match_bindings=match.bindings + ) + ) + assert self._target_num_outputs == _replacement_num_outputs + rewrite_cache = RewriteCache() + _, _to_insert = self._replace_node_pattern.to_ir( + model, match.bindings, self._target_num_outputs, rewrite_cache + ) + + return (match.values, _to_insert) + return None + + def apply_to_model(self, model: ir.Model, *, commute: bool = False): + # TODO(titaiwang): Why do we need RewriteRuleSet? + return RewriteRuleSet([self], commute=commute).apply_to_model(model) + + def count_matches(self, model: ir.Model, *, commute: bool = False): + return RewriteRuleSet([self], commute=commute).count_matches(model) + + def commute(self) -> list[RewriteRule]: + def replace_pattern(new_pattern): + """Return a shallow copy of self with node_pattern replaced by new_pattern.""" + rule = RewriteRule() + rule._condition_function = self._condition_function + rule._target_node_pattern = new_pattern + rule._target_num_outputs = self._target_num_outputs + rule._replace_node_pattern = self._replace_node_pattern + return rule + + return [replace_pattern(p) for p in self._target_node_pattern.commute()] + + +def _apply_deltas( + graph_or_function: ir.Graph | ir.Function, + deltas: list[tuple[int, tuple[list[ir.Node], list[ir.Node]]]], +): + nodes = graph_or_function.nodes + for i, delta in reversed(deltas): + deleted_nodes, inserted_nodes = delta + # Replace deleted nodes with inserted nodes. + # However, we merge the last deleted node and last inserted node + # to avoid replacing the values produced by the last deleted node + # in all places where they are used. 
So, we reuse the output + # values from the last deleted node and replace the node itself + # TODO: simplify this + last_deleted = deleted_nodes[-1] + last_inserted = inserted_nodes[-1] + + assert len(last_deleted.outputs) == len(last_inserted.outputs) + del last_inserted.outputs[:] + for v in last_deleted.outputs: + v.node = last_inserted + last_inserted.outputs.append(v) + + del nodes[i] + for item in reversed(inserted_nodes): + nodes.insert(i, item) + + for _, delta in deltas: + deleted_nodes, inserted_nodes = delta + inserted_input_output = [] + for nd in inserted_nodes: + inserted_input_output += nd.inputs + nd.outputs + for n in deleted_nodes[0:-1]: + # Delete intermediary outputs from graph that are not used as + # outputs of the graph + for output in n.outputs: + if not output.is_output and output not in inserted_input_output: + graph_or_function.values.pop(output.name) + nodes.remove(n) + + +class RewriteRuleSet: + def __init__(self, rules: Sequence[RewriteRule], *, commute: bool = False) -> None: + if commute: + rules = list( + itertools.chain.from_iterable([rule.commute() for rule in rules]) + ) + self.rules = rules + + def _apply_to_graph_or_function( + self, + model: ir.Model, + graph_or_function: ir.Graph | ir.Function, + ) -> int: + count = 0 + deltas = [] + for i, node in enumerate(graph_or_function.nodes): + for rule in self.rules: + delta = rule.try_rewrite(model, node) + if delta is not None: + deltas.append((i, delta)) + count += 1 + break + _apply_deltas(graph_or_function, deltas) + return count + + def apply_to_model(self, model: ir.Model) -> int: + assert isinstance(model, ir.Model) + count = self._apply_to_graph_or_function(model, model.graph) + for function in model.functions: + count += self._apply_to_graph_or_function(model, function) + return count + + def _count_matches_in_graph_or_function( + self, model: ir.Model, graph_or_funciton: ir.Graph | ir.Function + ) -> int: + count = 0 + for node in graph_or_funciton.nodes: + for rule in self.rules: + if rule.matches(node, model): + count += 1 + break + return count + + def count_matches(self, model: onnx.ModelProto | ir.Model): + if isinstance(model, onnx.ModelProto): + model = irbuilder.build_ir(model) + else: + assert isinstance(model, ir.Model) + count = self._count_matches_in_graph_or_function(model, model.graph) + for function in model.functions: + count += self._count_matches_in_graph_or_function(model, function) + return count diff --git a/onnxscript/onnxrewriter/rewriter/pattern_test.py b/onnxscript/onnxrewriter/rewriter/pattern_test.py new file mode 100644 index 0000000000..09e01981e5 --- /dev/null +++ b/onnxscript/onnxrewriter/rewriter/pattern_test.py @@ -0,0 +1,312 @@ +import logging +import unittest + +import numpy as np +import onnx.parser + +from onnxrewriter.ir import irbuilder, protobuilder +from onnxrewriter.rewriter import cast_constant_of_shape, pattern + +logger = logging.getLogger(__name__) +op = pattern.onnxop +msft_op = pattern.msft_op + + +class ReciprocalMulTest(unittest.TestCase): + def rule(self) -> pattern.RewriteRule: + def reciprocal_mul_pattern(x, y): + return (1 / x) * y + + def div(x, y): + return y / x + + return pattern.RewriteRule(reciprocal_mul_pattern, div) + + def test_single_match(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x, float[N] y) => (float[N] z) + { + c1 = Constant() + t1 = Div(c1, x) + z1 = Mul(t1, y) + z = Identity(z1) + } + """ + ) + ir = irbuilder.build_ir(model) + count = self.rule().apply_to_model(ir) + self.assertEqual(count, 1) + 
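+ # Div(c1, x) and Mul(t1, y) collapse into a single Div; the Constant and the
+ # trailing Identity remain, giving 3 nodes (the now-unused Constant is left for a
+ # later DCE pass, as noted in pattern.Constant.matches).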
self.assertEqual(len(ir.graph.nodes), 3) + + def test_failed_match(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x, float[N] y) => (float[N] z) + { + c1 = Constant() + t1 = Div(c1, x) + z1 = Mul(t1, y) + z = Identity(z1) + } + """ + ) + ir = irbuilder.build_ir(model) + count = self.rule().apply_to_model(ir) + self.assertEqual(count, 0) + self.assertEqual(len(ir.graph.nodes), 4) + + def test_multiple_matches(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x, float[N] y) => (float[N] z) + { + # {c1, t1, z1} is a valid match + # {c2, t2, z2} is a valid match + # {c3, t3, z3} is a match, but cannot be replaced since t3 has other-uses. + c1 = Constant() + c2 = Constant() + t2 = Div(c2, y) + t1 = Div(c1, x) + z1 = Mul(t1, y) + z2 = Mul(t2, z1) + + c3 = Constant() + t3 = Div(c3, x) + z3 = Mul(t3, y) + reuse_t3 = Div(t3, x) + z = Add(z2, reuse_t3) + } + """ + ) + ir = irbuilder.build_ir(model) + count = self.rule().apply_to_model(ir) + self.assertEqual(count, 2) + self.assertEqual(len(ir.graph.nodes), 9) + + +class FastGeluTest(unittest.TestCase): + def rule(self) -> pattern.RewriteRule: + def fast_gelu_pattern1(x): + b = 0.044715 + c = 0.79788 + tanh = op.Tanh(c * (x + (x**3) * b)) + return (1.0 + tanh) * (0.5 * x) + + def fast_gelu(x): + return msft_op.FastGelu(x) + + return pattern.RewriteRule(fast_gelu_pattern1, fast_gelu) + + def long_form_rule(self) -> pattern.RewriteRule: + def fast_gelu_pattern1_long(x): + three = pattern.Constant(3) + x_cube = op.Pow(x, three) + b = pattern.Constant(0.044715) + x_cube_mul_b = op.Mul(x_cube, b) # support OR op.Mul(B, x_cube) + sum_ = op.Add(x, x_cube_mul_b) + c = pattern.Constant(0.79788) + c_times_sum = op.Mul(c, sum_) + tanh = op.Tanh(c_times_sum) + one = pattern.Constant(1.0) + one_plus_tanh = op.Add(one, tanh) + half = pattern.Constant(0.5) + half_x = op.Mul(half, x) + return op.Mul(one_plus_tanh, half_x) + + def fast_gelu(x): + return msft_op.FastGelu(x) + + return pattern.RewriteRule(fast_gelu_pattern1_long, fast_gelu) + + def _check(self, rule): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x, float[N] y) => (float[N] z) + { + three = Constant () + x_cube = Pow(x, three) + B = Constant () + x_cube_mul_B = Mul(x_cube, B) + sum = Add(x, x_cube_mul_B) + C = Constant () + C_times_sum = Mul(C, sum) + tanh = Tanh(C_times_sum) + one = Constant () + one_plus_tanh = Add(one, tanh) + half = Constant () + half_x = Mul(half, x) + z = Mul(one_plus_tanh, half_x) + } + """ + ) + ir = irbuilder.build_ir(model) + count = rule.apply_to_model(ir) + self.assertEqual(count, 1) + # 5 Constant nodes and 1 FastGelu node + self.assertEqual(len(ir.graph.nodes), 6) + + def test_short_rule(self): + self._check(self.rule()) + + def test_long_rule(self): + self._check(self.long_form_rule()) + + +class ConcatTest(unittest.TestCase): + def rule(self) -> pattern.RewriteRule: + def concat_pattern(x, y, axis): + seq = op.SequenceConstruct(x, y) + return op.ConcatFromSequence(seq, axis=axis) + + def concat(x, y, axis): + return op.Concat(x, y, axis=axis) + + return pattern.RewriteRule(concat_pattern, concat) + + def test_concat(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x, float[N] y) => (float[M] z) + { + t = SequenceConstruct (x, y) + z = ConcatFromSequence (t) + } + """ + ) + ir = irbuilder.build_ir(model) + count = self.rule().apply_to_model(ir) + self.assertEqual(count, 1) + self.assertEqual(len(ir.graph.nodes), 1) + + def test_concat_in_function(self): + model = onnx.parser.parse_model( + """ + 
+ agraph (float[N] x, float[M] y) => (float[Z] z) + { + z = afunction (x, y) + } + + afunction (x, y) => (z) + { + t = SequenceConstruct (x, y) + z = ConcatFromSequence (t) + } + """ + ) + ir = irbuilder.build_ir(model) + count = self.rule().apply_to_model(ir) + self.assertEqual(count, 1) + self.assertEqual(len(ir.functions), 1) + self.assertEqual(len(ir.functions[0].nodes), 1) + self.assertEqual(ir.functions[0].nodes[0].op_type, "Concat") + + +class RewriteRuleTest(unittest.TestCase): + def test_commute(self): + op = pattern.onnxop + + def add_0(x): + return x + 0 + + def identity(x): + return op.Identity(x) + + add_0_rule = pattern.RewriteRule(add_0, identity) + + model = onnx.parser.parse_model( + """ + + agraph (float[N] x) => (float[M] z) + { + zero = Constant () + z = Add (zero, x) + } + """ + ) + ir = irbuilder.build_ir(model) + count = pattern.RewriteRuleSet([add_0_rule], commute=True).apply_to_model(ir) + optimized_model = protobuilder.build_model_proto(ir) + self.assertEqual(count, 1) + nodes = optimized_model.graph.node + self.assertEqual(len(nodes), 2) + self.assertEqual(nodes[1].op_type, "Identity") + + def test_const_value(self): + op = pattern.onnxop + + def reshape(x, newshape): + return op.Reshape(x, newshape) + + def identity(x, newshape): + del newshape # Unused + return op.Identity(x) + + def _check_for_redundant_reshape(x, newshape): + oldshape = x.shape + if not isinstance(oldshape, list): + return False + newshape = newshape.value_as_np_array + if not isinstance(newshape, np.ndarray): + return False + newshape = newshape.tolist() + + if len(oldshape) != len(newshape): + return False + for d1, d2 in zip(oldshape, newshape): + if d1 != d2 and d2 != -1: # noqa: PLR1714 + return False + return True + + def check_for_redundant_reshape(bindings): + return _check_for_redundant_reshape(**bindings) + + rule = pattern.RewriteRule(reshape, identity, check_for_redundant_reshape) + + model = onnx.parser.parse_model( + """ + + agraph (float[10, 20, 30] x) => (float[10, 20, 30] z) + { + shape = Constant () + z = Reshape (x, shape) + } + """ + ) + ir = irbuilder.build_ir(model) + count = pattern.RewriteRuleSet([rule]).apply_to_model(ir) + optimized_model = protobuilder.build_model_proto(ir) + self.assertEqual(count, 1) + nodes = optimized_model.graph.node + self.assertEqual(len(nodes), 2) + self.assertEqual(nodes[1].op_type, "Identity") + + def test_delayed_run_provides_correct_bindings_for_multiple_matches(self): + model = onnx.parser.parse_model( + """ + + agraph (int64[2] input_x) => (float16[1, 4] output, float[1, 4] output2) + { + constant = ConstantOfShape (input_x) + output = Cast (constant) + constant2 = ConstantOfShape (input_x) + output2 = Cast (constant2) + } + """ + ) + ir = irbuilder.build_ir(model) + count = cast_constant_of_shape.rules.apply_to_model(ir) + self.assertEqual(count, 2) + self.assertEqual(len(ir.graph.nodes), 2) + self.assertEqual(ir.graph.nodes[0].attributes["value"].data_type, 10) + self.assertEqual(ir.graph.nodes[1].attributes["value"].data_type, 1) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/onnxrewriter/testing/__init__.py b/onnxscript/onnxrewriter/testing/__init__.py new file mode 100644 index 0000000000..0a594f3121 --- /dev/null +++ b/onnxscript/onnxrewriter/testing/__init__.py @@ -0,0 +1,130 @@ +from __future__ import annotations + +__all__ = ["assert_onnx_proto_equal"] + +import difflib +import typing +from typing import Any, Collection, Sequence + +import google.protobuf.message + +if typing.TYPE_CHECKING: + import 
onnx + + +def _opset_import_key(opset_import: onnx.OperatorSetIdProto) -> tuple[str, int]: + return (opset_import.domain, opset_import.version) + + +def _value_info_key(value_info: onnx.ValueInfoProto) -> str: + return value_info.name + + +def _function_key(function: onnx.FunctionProto) -> tuple[str, str, str]: + return (function.domain, function.name, function.overload) + + +def _find_duplicates(with_duplicates: Collection[Any]) -> list[Any]: + """Return a list of duplicated elements in a collection.""" + seen = set() + duplicates = [] + for x in with_duplicates: + if x in seen: + duplicates.append(x) + seen.add(x) + return duplicates + + +def assert_onnx_proto_equal( + a: google.protobuf.message.Message | Any, b: google.protobuf.message.Message | Any +) -> None: + """Assert that two ONNX protos are equal. + + Equality is defined as having the same fields with the same values. When + a field takes the default value, it is considered equal to the field + not being set. + + Sequential fields with name `opset_import`, `value_info`, and `functions` are + compared disregarding the order of their elements. + + Args: + a: The first ONNX proto. + b: The second ONNX proto. + """ + assert type(a) == type(b), f"Type not equal: {type(a)} != {type(b)}" + + a_fields = {field.name: value for field, value in a.ListFields()} + b_fields = {field.name: value for field, value in b.ListFields()} + all_fields = sorted(set(a_fields.keys()) | set(b_fields.keys())) + for field in all_fields: + # Obtain the default value if the field is not set. This way we can compare the two fields. + a_value = getattr(a, field) + b_value = getattr(b, field) + if ( + isinstance(a_value, Sequence) + and isinstance(b_value, Sequence) + and not isinstance(a_value, (str, bytes)) + and not isinstance(b_value, (str, bytes)) + ): + # Check length first + a_keys: list[Any] = [] + b_keys: list[Any] = [] + if field == "opset_import": + a_value = sorted(a_value, key=_opset_import_key) + b_value = sorted(b_value, key=_opset_import_key) + a_keys = [_opset_import_key(opset_import) for opset_import in a_value] + b_keys = [_opset_import_key(opset_import) for opset_import in b_value] + elif field == "value_info": + a_value = sorted(a_value, key=_value_info_key) + b_value = sorted(b_value, key=_value_info_key) + a_keys = [_value_info_key(value_info) for value_info in a_value] + b_keys = [_value_info_key(value_info) for value_info in b_value] + elif field == "functions": + a_value = sorted(a_value, key=_function_key) + b_value = sorted(b_value, key=_function_key) + a_keys = [_function_key(functions) for functions in a_value] + b_keys = [_function_key(functions) for functions in b_value] + + if a_keys != b_keys: + keys_only_in_a = set(a_keys) - set(b_keys) + keys_only_in_b = set(b_keys) - set(a_keys) + error_message = ( + f"Field {field} not equal: keys_only_in_a={keys_only_in_a}, keys_only_in_b={keys_only_in_b}. " + f"Field type: {type(a_value)}. 
" + f"Duplicated a_keys: {_find_duplicates(a_keys)}, duplicated b_keys: {_find_duplicates(b_keys)}" + ) + raise AssertionError(error_message) + elif len(a_value) != len(b_value): + error_message = ( + f"Field {field} not equal: len(a)={len(a_value)}, len(b)={len(b_value)} " + f"Field type: {type(a_value)}" + ) + raise AssertionError(error_message) + # Check every element + for i in range(len(a_value)): + a_value_i = a_value[i] + b_value_i = b_value[i] + if isinstance( + a_value_i, google.protobuf.message.Message + ) and isinstance(b_value_i, google.protobuf.message.Message): + try: + assert_onnx_proto_equal(a_value_i, b_value_i) + except AssertionError as e: + error_message = f"Field {field} index {i} in sequence not equal. type(a_value_i): {type(a_value_i)}, type(b_value_i): {type(b_value_i)}, a_value_i: {a_value_i}, b_value_i: {b_value_i}" + raise AssertionError(error_message) from e + elif a_value_i != b_value_i: + error_message = f"Field {field} index {i} in sequence not equal. type(a_value_i): {type(a_value_i)}, type(b_value_i): {type(b_value_i)}" + for line in difflib.ndiff( + str(a_value_i).splitlines(), str(b_value_i).splitlines() + ): + error_message += "\n" + line + raise AssertionError(error_message) + elif isinstance(a_value, google.protobuf.message.Message) and isinstance( + b_value, google.protobuf.message.Message + ): + assert_onnx_proto_equal(a_value, b_value) + elif a_value != b_value: + error_message = ( + f"Field {field} not equal. field_a: {a_value}, field_b: {b_value}" + ) + raise AssertionError(error_message) diff --git a/onnxscript/onnxrewriter/utils/__init__.py b/onnxscript/onnxrewriter/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/onnxscript/onnxrewriter/utils/evaluation_utils.py b/onnxscript/onnxrewriter/utils/evaluation_utils.py new file mode 100644 index 0000000000..58cbc4f5b3 --- /dev/null +++ b/onnxscript/onnxrewriter/utils/evaluation_utils.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +import pathlib + +import numpy as np +import onnx +from onnx import helper as onnx_helper + + +def load_test_data( + qual_model_dir: str, input_names: list[str] +) -> tuple[dict[str, np.ndarray], list[np.ndarray]]: + test_data_dir = pathlib.Path(qual_model_dir) / "test_data_set_0" + inputs = {} + expected_outputs = [] + for test_data in test_data_dir.glob("input_*.pb"): + idx = int(test_data.stem[len("input_") :]) + input_name = input_names[idx] + input_data = onnx.TensorProto() + with open(test_data, "rb") as f: + input_data.ParseFromString(f.read()) + inputs[input_name] = onnx.numpy_helper.to_array(input_data) + + output_file_paths = list(test_data_dir.glob("output_*.pb")) + expected_outputs = [None] * len(output_file_paths) + for test_data in test_data_dir.glob("output_*.pb"): + idx = int(test_data.stem[len("output_") :]) + output_data = onnx.TensorProto() + with open(test_data, "rb") as f: + output_data.ParseFromString(f.read()) + expected_outputs[idx] = onnx.numpy_helper.to_array(output_data) # type: ignore[call-overload] + + assert all(name in inputs for name in input_names), "Some inputs are missing." + assert not any( + output is None for output in expected_outputs + ), "Some outputs are missing." + + return inputs, expected_outputs # type: ignore[return-value] + + +def generate_random_input(model: onnx.ModelProto) -> dict[str, np.ndarray]: + """Generate random input for the model. + + NOTE: This is unused. There is parity issue with randomly generated data. Need investigation. 
+ """ + inputs = {} + for _, input in enumerate(model.graph.input): + shape = [d.dim_value for d in input.type.tensor_type.shape.dim] + np_dtype = onnx_helper.tensor_dtype_to_np_dtype( + input.type.tensor_type.elem_type + ) + if np_dtype is None: + raise ValueError(f"Unsupported dtype: {input.type.tensor_type.elem_type}") + if np_dtype in (np.float16, np.float32, np.float64): + inputs[input.name] = np.random.rand(*shape).astype(np_dtype) - 0.5 + else: + inputs[input.name] = np.random.randint(3, 100, size=shape, dtype=np_dtype) + return inputs diff --git a/onnxscript/onnxrewriter/utils/timing_utils.py b/onnxscript/onnxrewriter/utils/timing_utils.py new file mode 100644 index 0000000000..4661e055f0 --- /dev/null +++ b/onnxscript/onnxrewriter/utils/timing_utils.py @@ -0,0 +1,33 @@ +import time + +import onnx + +from onnxrewriter import optimizer + +# from onnxrewriter.rewriter.rules import all_rules + + +def timeit(f, message): + def timed(*args, **kw): + ts = time.time() + result = f(*args, **kw) + te = time.time() + print(f"{message} time: {te-ts}") + return result + + return timed + + +load = timeit(onnx.load, "Load") + +save = timeit(onnx.save, "Save") + +infer = timeit(onnx.shape_inference.infer_shapes, "Infer") + +fold_constants = timeit(optimizer.fold_constants, "Fold Constants") + +remove_unused = timeit(optimizer.remove_unused_nodes, "Remove Unused") + +optimize = timeit(optimizer.optimize, "Optimize") + +# rewrite = timeit(all_rules.apply_to_model, "Rewrite") diff --git a/onnxscript/onnxrewriter/utils/utils.py b/onnxscript/onnxrewriter/utils/utils.py new file mode 100644 index 0000000000..26ef525b1c --- /dev/null +++ b/onnxscript/onnxrewriter/utils/utils.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +from typing import Any + +import onnx + + +def normalize_domain(d: str) -> str: + return "" if d == "ai.onnx" else d + + +def is_onnx_domain(d: str) -> bool: + return normalize_domain(d) == "" + + +def is_onnx_op(node: onnx.NodeProto, op_type: str) -> bool: + return is_onnx_domain(node.domain) and node.op_type == op_type + + +def is_control_flow_op(node: onnx.NodeProto) -> bool: + return any(attr.HasField("g") or len(attr.graphs) > 0 for attr in node.attribute) + + +def get_node_attr_value(node: onnx.NodeProto, attr_name: str, default: Any) -> Any: + matching = [x for x in node.attribute if x.name == attr_name] + if len(matching) > 1: + raise ValueError(f"Node has multiple attributes with name {attr_name}") + if len(matching) < 1: + return default + return onnx.helper.get_attribute_value(matching[0]) + + +def get_initializer_type(initializer: onnx.TensorProto) -> onnx.TypeProto: + type = onnx.TypeProto() + type.tensor_type.elem_type = initializer.data_type + dims = type.tensor_type.shape.dim + for dim in initializer.dims: + dims.add().dim_value = dim + return type + + +def get_constant_node_value(node: onnx.NodeProto, name: str) -> onnx.TensorProto | None: + if ( + node.op_type != "Constant" + or node.domain not in {"", "ai.onnx"} + or len(node.attribute) != 1 + ): + return None + attr = node.attribute[0] + if attr.ref_attr_name: + return None + attr_name = attr.name + value = onnx.helper.get_attribute_value(attr) + + if isinstance(value, onnx.TensorProto): + # Two names exist in this case: we use tensorproto as is (with original name) + return value + shape: list[int] + if attr_name == "value_int": + dtype = onnx.TensorProto.INT64 + shape = [] + value = [value] + elif attr_name == "value_float": + dtype = onnx.TensorProto.FLOAT + shape = [] + value = [value] + elif 
attr_name == "value_string": + dtype = onnx.TensorProto.STRING + shape = [] + value = [value] + elif attr_name == "value_ints": + dtype = onnx.TensorProto.INT64 + shape = [len(value)] + elif attr_name == "value_floats": + dtype = onnx.TensorProto.FLOAT + shape = [len(value)] + elif attr_name == "value_strings": + dtype = onnx.TensorProto.STRING + shape = [len(value)] + else: + return None # sparse tensors not handled + return onnx.helper.make_tensor(name, dtype, shape, value)
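+ 
+ # Example (illustrative): for a Constant node carrying `value_floats = [1.0, 2.0]`,
+ # get_constant_node_value(node, "c") returns a rank-1 FLOAT TensorProto of shape [2]
+ # named "c"; nodes with ref_attr_name set, or with sparse/unsupported attributes,
+ # yield None instead.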