Skip to content

Commit b90c1ad

Browse files
authored
Refine shape optimization (#2336)
Refine the recently introduced shape optimization: more patterns showed up in the openai whisper model, extracting different slices of the concatenated shape. The optimization improves MHA fusions (which are other handicapped by the reuse of some intermediate values that prevent fusion). --------- Signed-off-by: Ganesan Ramalingam <[email protected]>
1 parent 276bf27 commit b90c1ad

File tree

4 files changed

+103
-9
lines changed

4 files changed

+103
-9
lines changed

onnxscript/rewriter/_ir_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def get_numpy_value(val: ir.Value | None) -> np.ndarray | None:
6868
"""
6969
if val is None:
7070
return None
71-
const_value = val.const_value
71+
const_value = get_const_value(val)
7272
if const_value is not None:
7373
try:
7474
return const_value.numpy()

onnxscript/rewriter/ort_fusions/_core.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def _pre_optimize(model: ir.Model) -> ir.Model:
5353
shape_inference.infer_shapes(model)
5454
optimize(model)
5555
shape_optimization.rules.apply_to_model(model)
56+
optimize(model)
5657
return model
5758

5859

onnxscript/rewriter/ort_fusions/shape_optimization.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from __future__ import annotations
77

88
import onnxscript.ir as ir
9+
import onnxscript.rewriter._ir_utils as _ir_utils
910
import onnxscript.rewriter.pattern as pattern
1011

1112

@@ -17,15 +18,19 @@ def __init__(self):
1718
It can be simplified away.
1819
"""
1920

20-
def pattern(self, op, x, dim0, dim1, dim2, dim3):
21+
def pattern(self, op, x, dim0, dim1, dim2, dim3, start, end):
2122
shape = op.Concat(dim0, dim1, dim2, dim3, axis=0)
22-
reshaped = op.Reshape(x, shape, allowzero=0)
23+
# Note: The allowzero=1 attribute enables us to infer that the shape of the
24+
# reshaped tensor is the same as the value of the shape parameter below.
25+
# Otherwise, we need to know that there are no zeros in the value of "shape"
26+
# for this optimization to be valid.
27+
reshaped = op.Reshape(x, shape, allowzero=1)
2328
transposed = op.Transpose(reshaped, perm=[0, 2, 1, 3])
24-
final_shape = op.Shape(transposed, _outputs=["final_shape"], start=0)
25-
final_dim = op.Slice(final_shape, [-2], [-1])
29+
final_shape = op.Shape(transposed, _outputs=["final_shape"])
30+
final_dim = op.Slice(final_shape, start, end)
2631
return final_dim
2732

28-
def check(self, context, dim0, dim1, dim2, dim3, final_shape, **_) -> bool:
33+
def check(self, context, dim0, dim1, dim2, dim3, final_shape, start, end, **_) -> bool:
2934
# All of the dimensions should have shape [1]
3035
for dim in (dim0, dim1, dim2, dim3):
3136
if dim.shape is None or dim.shape.dims != (1,):
@@ -37,11 +42,22 @@ def check(self, context, dim0, dim1, dim2, dim3, final_shape, **_) -> bool:
3742
return False
3843
if "start" in shape_node.attributes:
3944
start_attr = shape_node.attributes["start"]
40-
return isinstance(start_attr, ir.Attr) and start_attr.value == 0
45+
if not (isinstance(start_attr, ir.Attr) and start_attr.value == 0):
46+
return False
47+
self._start_val = _ir_utils.get_singleton_value(start)
48+
self._end_val = _ir_utils.get_singleton_value(end)
49+
if self._start_val is None or self._end_val is None:
50+
return False
4151
return True
4252

43-
def rewrite(self, op, dim1, **_):
44-
return dim1
53+
def rewrite(self, op, dim0, dim1, dim2, dim3, **_):
54+
transposed_dims = [dim0, dim2, dim1, dim3]
55+
sliced_result = transposed_dims[self._start_val : self._end_val]
56+
if len(sliced_result) == 0:
57+
return op.Constant(value_ints=[])
58+
if len(sliced_result) == 1:
59+
return op.Identity(sliced_result[0])
60+
return op.Concat(*sliced_result, axis=0)
4561

4662

4763
rules = pattern.RewriteRuleSet([ExtractDim.rule()])
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
import unittest
4+
5+
import numpy as np
6+
import onnx
7+
import parameterized
8+
9+
from onnxscript import FLOAT, INT64, ir, opset18, script
10+
from onnxscript.rewriter.ort_fusions import shape_optimization
11+
12+
13+
def _make_model(starts: list[int], ends: list[int]) -> onnx.ModelProto:
14+
@script()
15+
def model_script(
16+
x: FLOAT["N"], # noqa: F821
17+
dim0: INT64[1],
18+
dim1: INT64[1],
19+
dim2: INT64[1],
20+
dim3: INT64[1],
21+
) -> INT64["M"]: # noqa: F821
22+
shape = opset18.Concat(dim0, dim1, dim2, dim3, axis=0)
23+
reshaped = opset18.Reshape(x, shape, allowzero=1)
24+
transposed = opset18.Transpose(reshaped, perm=[0, 2, 1, 3])
25+
final_shape = opset18.Shape(transposed)
26+
final_dim = opset18.Slice(final_shape, starts, ends)
27+
return opset18.Add(final_dim, final_dim)
28+
29+
model_proto = model_script.to_model_proto()
30+
return model_proto
31+
32+
33+
# Example input data
34+
_model_inputs = {
35+
"x": np.zeros((24,), dtype=np.float32),
36+
"dim0": np.array([2], dtype=np.int64),
37+
"dim1": np.array([3], dtype=np.int64),
38+
"dim2": np.array([4], dtype=np.int64),
39+
"dim3": np.array([1], dtype=np.int64),
40+
}
41+
42+
43+
class ShapeOptimizationTest(unittest.TestCase):
44+
@parameterized.parameterized.expand(
45+
[
46+
([0], [1], "singleton"),
47+
([1], [3], "two_elements"),
48+
([1], [-1], "negative_index"),
49+
([-2], [1000], "out_of_bounds"),
50+
([-200], [-1], "negative_out_of_bounds"),
51+
([2], [2], "empty_slice"),
52+
]
53+
)
54+
def test_shape_optimization(self, starts: list[int], ends: list[int], _name: str):
55+
model_proto = _make_model(starts, ends)
56+
model = ir.serde.deserialize_model(model_proto)
57+
58+
count = shape_optimization.rules.apply_to_model(model)
59+
self.assertEqual(count, 1)
60+
optimized_proto = ir.serde.serialize_model(model)
61+
62+
import onnxruntime as ort
63+
64+
sess = ort.InferenceSession(
65+
model_proto.SerializeToString(), providers=["CPUExecutionProvider"]
66+
)
67+
outputs = sess.run(None, _model_inputs)
68+
sess = ort.InferenceSession(
69+
optimized_proto.SerializeToString(), providers=["CPUExecutionProvider"]
70+
)
71+
optimized_outputs = sess.run(None, _model_inputs)
72+
for orig, opt in zip(outputs, optimized_outputs):
73+
np.testing.assert_array_equal(orig, opt)
74+
75+
76+
if __name__ == "__main__":
77+
unittest.main()

0 commit comments

Comments
 (0)