diff --git a/onnxscript/rewriter/_ir_utils.py b/onnxscript/rewriter/_ir_utils.py
index d6c4177ae8..6af84dd1d8 100644
--- a/onnxscript/rewriter/_ir_utils.py
+++ b/onnxscript/rewriter/_ir_utils.py
@@ -68,7 +68,7 @@ def get_numpy_value(val: ir.Value | None) -> np.ndarray | None:
     """
     if val is None:
         return None
-    const_value = val.const_value
+    const_value = get_const_value(val)
     if const_value is not None:
         try:
             return const_value.numpy()
diff --git a/onnxscript/rewriter/ort_fusions/_core.py b/onnxscript/rewriter/ort_fusions/_core.py
index 79de57f335..c0d07183cd 100644
--- a/onnxscript/rewriter/ort_fusions/_core.py
+++ b/onnxscript/rewriter/ort_fusions/_core.py
@@ -53,6 +53,7 @@ def _pre_optimize(model: ir.Model) -> ir.Model:
     shape_inference.infer_shapes(model)
     optimize(model)
     shape_optimization.rules.apply_to_model(model)
+    optimize(model)
     return model
 
 
diff --git a/onnxscript/rewriter/ort_fusions/shape_optimization.py b/onnxscript/rewriter/ort_fusions/shape_optimization.py
index d8399b7293..c4e34b42af 100644
--- a/onnxscript/rewriter/ort_fusions/shape_optimization.py
+++ b/onnxscript/rewriter/ort_fusions/shape_optimization.py
@@ -6,6 +6,7 @@
 from __future__ import annotations
 
 import onnxscript.ir as ir
+import onnxscript.rewriter._ir_utils as _ir_utils
 import onnxscript.rewriter.pattern as pattern
 
 
@@ -17,15 +18,19 @@ def __init__(self):
         It can be simplified away.
         """
 
-    def pattern(self, op, x, dim0, dim1, dim2, dim3):
+    def pattern(self, op, x, dim0, dim1, dim2, dim3, start, end):
         shape = op.Concat(dim0, dim1, dim2, dim3, axis=0)
-        reshaped = op.Reshape(x, shape, allowzero=0)
+        # Note: The allowzero=1 attribute enables us to infer that the shape of the
+        # reshaped tensor is the same as the value of the shape parameter below.
+        # Otherwise, we need to know that there are no zeros in the value of "shape"
+        # for this optimization to be valid.
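+        # (For example, with the default allowzero=0, a zero in "shape" means
+        # "copy the corresponding dimension from the input x", so the reshaped
+        # tensor's shape need not equal the value of "shape".)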
+        reshaped = op.Reshape(x, shape, allowzero=1)
         transposed = op.Transpose(reshaped, perm=[0, 2, 1, 3])
-        final_shape = op.Shape(transposed, _outputs=["final_shape"], start=0)
-        final_dim = op.Slice(final_shape, [-2], [-1])
+        final_shape = op.Shape(transposed, _outputs=["final_shape"])
+        final_dim = op.Slice(final_shape, start, end)
         return final_dim
 
-    def check(self, context, dim0, dim1, dim2, dim3, final_shape, **_) -> bool:
+    def check(self, context, dim0, dim1, dim2, dim3, final_shape, start, end, **_) -> bool:
         # All of the dimensions should have shape [1]
         for dim in (dim0, dim1, dim2, dim3):
             if dim.shape is None or dim.shape.dims != (1,):
@@ -37,11 +42,22 @@ def check(self, context, dim0, dim1, dim2, dim3, final_shape, **_) -> bool:
             return False
         if "start" in shape_node.attributes:
             start_attr = shape_node.attributes["start"]
-            return isinstance(start_attr, ir.Attr) and start_attr.value == 0
+            if not (isinstance(start_attr, ir.Attr) and start_attr.value == 0):
+                return False
+        # Stash the constant start/end values of the Slice for use in rewrite().
+        self._start_val = _ir_utils.get_singleton_value(start)
+        self._end_val = _ir_utils.get_singleton_value(end)
+        if self._start_val is None or self._end_val is None:
+            return False
         return True
 
-    def rewrite(self, op, dim1, **_):
-        return dim1
+    def rewrite(self, op, dim0, dim1, dim2, dim3, **_):
+        # The shape of the transposed tensor, given perm=[0, 2, 1, 3]:
+        transposed_dims = [dim0, dim2, dim1, dim3]
+        # Python slicing matches the ONNX Slice semantics needed here: negative
+        # indices count from the end, and out-of-range indices are clamped.
+        sliced_result = transposed_dims[self._start_val : self._end_val]
+        if len(sliced_result) == 0:
+            return op.Constant(value_ints=[])
+        if len(sliced_result) == 1:
+            return op.Identity(sliced_result[0])
+        return op.Concat(*sliced_result, axis=0)
 
 
 rules = pattern.RewriteRuleSet([ExtractDim.rule()])
diff --git a/onnxscript/rewriter/ort_fusions/shape_optimization_test.py b/onnxscript/rewriter/ort_fusions/shape_optimization_test.py
new file mode 100644
index 0000000000..f563ef58d5
--- /dev/null
+++ b/onnxscript/rewriter/ort_fusions/shape_optimization_test.py
@@ -0,0 +1,77 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
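+
+"""Tests for the Reshape/Transpose/Shape/Slice dimension-extraction rewrite.
+
+Each case builds a small model, applies shape_optimization.rules to it, and
+checks that the original and optimized models produce the same outputs when
+run under onnxruntime.
+"""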
+
+import unittest
+
+import numpy as np
+import onnx
+import onnxruntime as ort
+import parameterized
+
+from onnxscript import FLOAT, INT64, ir, opset18, script
+from onnxscript.rewriter.ort_fusions import shape_optimization
+
+
+def _make_model(starts: list[int], ends: list[int]) -> onnx.ModelProto:
+    @script()
+    def model_script(
+        x: FLOAT["N"],  # noqa: F821
+        dim0: INT64[1],
+        dim1: INT64[1],
+        dim2: INT64[1],
+        dim3: INT64[1],
+    ) -> INT64["M"]:  # noqa: F821
+        shape = opset18.Concat(dim0, dim1, dim2, dim3, axis=0)
+        reshaped = opset18.Reshape(x, shape, allowzero=1)
+        transposed = opset18.Transpose(reshaped, perm=[0, 2, 1, 3])
+        final_shape = opset18.Shape(transposed)
+        final_dim = opset18.Slice(final_shape, starts, ends)
+        return opset18.Add(final_dim, final_dim)
+
+    model_proto = model_script.to_model_proto()
+    return model_proto
+
+
+# Example inputs: a 24-element vector to be reshaped to [2, 3, 4, 1].
+_model_inputs = {
+    "x": np.zeros((24,), dtype=np.float32),
+    "dim0": np.array([2], dtype=np.int64),
+    "dim1": np.array([3], dtype=np.int64),
+    "dim2": np.array([4], dtype=np.int64),
+    "dim3": np.array([1], dtype=np.int64),
+}
+
+
+class ShapeOptimizationTest(unittest.TestCase):
+    @parameterized.parameterized.expand(
+        [
+            ([0], [1], "singleton"),
+            ([1], [3], "two_elements"),
+            ([1], [-1], "negative_index"),
+            ([-2], [1000], "out_of_bounds"),
+            ([-200], [-1], "negative_out_of_bounds"),
+            ([2], [2], "empty_slice"),
+        ]
+    )
+    def test_shape_optimization(self, starts: list[int], ends: list[int], _name: str):
+        model_proto = _make_model(starts, ends)
+        model = ir.serde.deserialize_model(model_proto)
+
+        # The rewrite rule should fire exactly once per model.
+        count = shape_optimization.rules.apply_to_model(model)
+        self.assertEqual(count, 1)
+        optimized_proto = ir.serde.serialize_model(model)
+
+        # The original and optimized models must produce identical outputs.
+        sess = ort.InferenceSession(
+            model_proto.SerializeToString(), providers=["CPUExecutionProvider"]
+        )
+        outputs = sess.run(None, _model_inputs)
+        optimized_sess = ort.InferenceSession(
+            optimized_proto.SerializeToString(), providers=["CPUExecutionProvider"]
+        )
+        optimized_outputs = optimized_sess.run(None, _model_inputs)
+        for orig, opt in zip(outputs, optimized_outputs):
+            np.testing.assert_array_equal(orig, opt)
+
+
+if __name__ == "__main__":
+    unittest.main()