Skip to content

Commit b34cd9c

Browse files
authored
Optimize causal mask shape (#2325)
The generation of the causal mask's shape (produced by the translation of scalar_dot_product_attention) interferes with the subsequent fusion optimizations (because it makes use of the shape of the intermediate matmul value). This PR introduces a very specific fusion/rewrite to eliminate this redundant computation of the "sequence length" dimension. --------- Signed-off-by: Ganesan Ramalingam <[email protected]>
1 parent 8540282 commit b34cd9c

File tree

2 files changed

+49
-0
lines changed

2 files changed

+49
-0
lines changed

onnxscript/rewriter/ort_fusions/_core.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
import onnxscript.ir as ir
6+
import onnxscript.rewriter.ort_fusions.shape_optimization as shape_optimization
67
from onnxscript.ir.passes.common import shape_inference
78
from onnxscript.optimizer import optimize
89
from onnxscript.rewriter import rewrite
@@ -51,6 +52,7 @@ def _pre_optimize(model: ir.Model) -> ir.Model:
5152
# incorporated in our optimizer.
5253
shape_inference.infer_shapes(model)
5354
optimize(model)
55+
shape_optimization.rules.apply_to_model(model)
5456
return model
5557

5658

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
"""Optimization for shape operations."""
5+
6+
from __future__ import annotations
7+
8+
import onnxscript.ir as ir
9+
import onnxscript.rewriter.pattern as pattern
10+
11+
12+
class ExtractDim(pattern.RewriteRuleClassBase):
13+
def __init__(self):
14+
super().__init__(remove_nodes=False)
15+
16+
"""This is a pattern observed in causal mask generation that hinders fusion optimizations.
17+
It can be simplified away.
18+
"""
19+
20+
def pattern(self, op, x, dim0, dim1, dim2, dim3):
21+
shape = op.Concat(dim0, dim1, dim2, dim3, axis=0)
22+
reshaped = op.Reshape(x, shape, allowzero=0)
23+
transposed = op.Transpose(reshaped, perm=[0, 2, 1, 3])
24+
final_shape = op.Shape(transposed, _outputs=["final_shape"], start=0)
25+
final_dim = op.Slice(final_shape, [-2], [-1])
26+
return final_dim
27+
28+
def check(self, context, dim0, dim1, dim2, dim3, final_shape, **_) -> bool:
29+
# All of the dimensions should have shape [1]
30+
for dim in (dim0, dim1, dim2, dim3):
31+
if dim.shape is None or dim.shape.dims != (1,):
32+
return False
33+
34+
# The Shape op should return the full shape, not a slice of the shape.
35+
shape_node = final_shape.producer()
36+
if "end" in shape_node.attributes:
37+
return False
38+
if "start" in shape_node.attributes:
39+
start_attr = shape_node.attributes["start"]
40+
return isinstance(start_attr, ir.Attr) and start_attr.value == 0
41+
return True
42+
43+
def rewrite(self, op, dim1, **_):
44+
return dim1
45+
46+
47+
rules = pattern.RewriteRuleSet([ExtractDim.rule()])

0 commit comments

Comments
 (0)