Skip to content

Optimize causal mask shape #2325

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions onnxscript/rewriter/ort_fusions/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import onnxscript.ir as ir
import onnxscript.rewriter.ort_fusions.shape_optimization as shape_optimization
from onnxscript.ir.passes.common import shape_inference
from onnxscript.optimizer import optimize
from onnxscript.rewriter import rewrite
Expand Down Expand Up @@ -51,6 +52,7 @@ def _pre_optimize(model: ir.Model) -> ir.Model:
# incorporated in our optimizer.
shape_inference.infer_shapes(model)
optimize(model)
shape_optimization.rules.apply_to_model(model)
return model


Expand Down
47 changes: 47 additions & 0 deletions onnxscript/rewriter/ort_fusions/shape_optimization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Optimization for shape operations."""

from __future__ import annotations

import onnxscript.ir as ir
import onnxscript.rewriter.pattern as pattern


class ExtractDim(pattern.RewriteRuleClassBase):
def __init__(self):
super().__init__(remove_nodes=False)

"""This is a pattern observed in causal mask generation that hinders fusion optimizations.
It can be simplified away.
"""

def pattern(self, op, x, dim0, dim1, dim2, dim3):
shape = op.Concat(dim0, dim1, dim2, dim3, axis=0)
reshaped = op.Reshape(x, shape, allowzero=0)
transposed = op.Transpose(reshaped, perm=[0, 2, 1, 3])
final_shape = op.Shape(transposed, _outputs=["final_shape"], start=0)
final_dim = op.Slice(final_shape, [-2], [-1])
return final_dim

def check(self, context, dim0, dim1, dim2, dim3, final_shape, **_) -> bool:
# All of the dimensions should have shape [1]
for dim in (dim0, dim1, dim2, dim3):
if dim.shape is None or dim.shape.dims != (1,):
return False

Check warning on line 32 in onnxscript/rewriter/ort_fusions/shape_optimization.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/ort_fusions/shape_optimization.py#L32

Added line #L32 was not covered by tests

# The Shape op should return the full shape, not a slice of the shape.
shape_node = final_shape.producer()

Check warning on line 35 in onnxscript/rewriter/ort_fusions/shape_optimization.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/ort_fusions/shape_optimization.py#L35

Added line #L35 was not covered by tests
if "end" in shape_node.attributes:
return False

Check warning on line 37 in onnxscript/rewriter/ort_fusions/shape_optimization.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/ort_fusions/shape_optimization.py#L37

Added line #L37 was not covered by tests
if "start" in shape_node.attributes:
start_attr = shape_node.attributes["start"]
return isinstance(start_attr, ir.Attr) and start_attr.value == 0
return True

Check warning on line 41 in onnxscript/rewriter/ort_fusions/shape_optimization.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/ort_fusions/shape_optimization.py#L39-L41

Added lines #L39 - L41 were not covered by tests

def rewrite(self, op, dim1, **_):
return dim1

Check warning on line 44 in onnxscript/rewriter/ort_fusions/shape_optimization.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/ort_fusions/shape_optimization.py#L44

Added line #L44 was not covered by tests


rules = pattern.RewriteRuleSet([ExtractDim.rule()])
Loading