Skip to content

Fusion for partial rotary embedding #2095

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Mar 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions onnxscript/rewriter/_ir_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,15 @@ def get_numpy_value(val: ir.Value | None) -> np.ndarray | None:
return None


def get_singleton_value(val: ir.Value | None):
"""Returns element of a single element tensor constant value, and None otherwise."""
def get_singleton_value(val: ir.Value | None, rank: int | None = None):
"""Returns element of a single element tensor constant value, and None otherwise.

If rank is specified, it checks that the value has the given rank.
"""
np_val = get_numpy_value(val)
if np_val is not None and np_val.size == 1:
return np_val.item()
if rank is None or (np_val.ndim == rank):
return np_val.item()
return None


Expand Down
67 changes: 67 additions & 0 deletions onnxscript/rewriter/ort_fusions/_rotary_embedding_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from onnxscript.onnx_opset import opset18 as op
from onnxscript.onnx_types import FLOAT, INT64

# A simple rotary embedding example


# x: [B, H, S, E]
# position_ids: [B, S]
Expand Down Expand Up @@ -57,6 +59,7 @@ def test_case_1():
return _TestCase1()


# A simple rotary embedding example with 1D position_ids
# x: [B, H, S, E]
# position_ids: [S]
@script()
Expand Down Expand Up @@ -101,3 +104,67 @@ def get_ort_inputs(self):

def test_case_2():
return _TestCase2()


# A partial rotary embedding example:

rotary_embedding_dim = 32 # Abbreviated as "rd" in shape descriptors below
half_rotary_embedding_dim = rotary_embedding_dim // 2
# A random inverse frequency tensor for the sake of this example.
inv_freqs_value = numpy.random.rand(1, half_rotary_embedding_dim, 1).astype(numpy.float32)


@script()
def _partial_rotary_script(position_ids, query):
inv_freqs = op.Constant(value=inv_freqs_value) # [1, rd/2, 1]
position_ids_3d = op.Unsqueeze(position_ids, 1) # [B, 1, S]
position_ids_3d_float = op.Cast(position_ids_3d, to=1)
matmul = op.MatMul(inv_freqs, position_ids_3d_float) # [B, rd/2, S]
transpose = op.Transpose(matmul, perm=[0, 2, 1]) # [B, S, rd/2]
cat = op.Concat(transpose, transpose, axis=-1) # [B, S, rd]
cos_3d = op.Cos(cat) # [B, S, rd]
sin_3d = op.Sin(cat) # [B, S, rd]
# Split the query for partial embedding
to_embed = op.Slice(query, [0], [32], [3], [1])
unembedded = op.Slice(query, [32], [9223372036854775807], [3], [1])
cos_4d = op.Unsqueeze(cos_3d, 1) # [B, 1, S, rd]
sin_4d = op.Unsqueeze(sin_3d, 1) # [B, 1, S, rd]
# Compute rotation of X as X * cos + rotate_half(X) * sin, where rotate_half(X)
# essentially represents X rotated by 90 degrees
to_embed_times_cos = op.Mul(to_embed, cos_4d)
to_embed_x = op.Slice(to_embed, [0], [16], [3], [1])
to_embed_y = op.Slice(to_embed, [16], [9223372036854775807], [3], [1])
minus_to_embed_y = op.Neg(to_embed_y)
to_embed_rotated_90 = op.Concat(minus_to_embed_y, to_embed_x, axis=-1)
to_embed_rotated_90_times_sin = op.Mul(to_embed_rotated_90, sin_4d)
embedded = op.Add(to_embed_times_cos, to_embed_rotated_90_times_sin)
final = op.Concat(embedded, unembedded, axis=-1)
return final


class _PartialRotaryTestCase:
def get_onnx_model(self):
if not hasattr(self, "_onnx_model"):
model_proto = _partial_rotary_script.to_model_proto(
input_types=(
INT64["Batchsize", "Sequence"],
FLOAT["Batchsize", 32, "Sequence", 80],
),
output_types=(FLOAT["Batchsize", 32, "Sequence", 80],),
)
model = ir.serde.deserialize_model(model_proto)
self._onnx_model = model
return self._onnx_model

def get_ort_inputs(self):
if not hasattr(self, "_ort_inputs"):
inputs = {
"query": numpy.random.rand(1, 32, 8, 80).astype(numpy.float32),
"position_ids": numpy.arange(8, dtype=numpy.int64).reshape(1, 8),
}
self._ort_inputs = inputs
return self._ort_inputs


def partial_rotary_test_case():
return _PartialRotaryTestCase()
13 changes: 9 additions & 4 deletions onnxscript/rewriter/ort_fusions/cos_sin_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,16 +152,21 @@ def rewrite(
_cast_const_freqs = CosSinCacheFusion.rule(
"CosSinCache_cast_const_freqs", 2048, cast=True, const_freqs=True
)
_cast = CosSinCacheFusion.rule(
"CosSinCache_cast_no_const_freqs", 2048, cast=True, const_freqs=False
_cast = CosSinCacheFusion.rule("CosSinCache_cast", 2048, cast=True, const_freqs=False)
_const_freqs = CosSinCacheFusion.rule(
"CosSinCache_const_freqs", 2048, cast=False, const_freqs=True
)
_basic = CosSinCacheFusion.rule("CosSinCache", 2048, cast=False)

cos_sin_cache_rules = pattern.RewriteRuleSet([_cast, _cast_const_freqs, _basic])
cos_sin_cache_rules = pattern.RewriteRuleSet([_cast, _cast_const_freqs, _const_freqs, _basic])


def fuse_cos_sin_cache(model: ir.Model) -> int:
def fuse_cos_sin_cache(model: ir.Model, debug: bool = False) -> int:
count = cos_sin_cache_rules.apply_to_model(model)
if count == 0 and debug:
tracer = pattern.MatchingTracer()
cos_sin_cache_rules.apply_to_model(model, tracer=tracer)
tracer.report()
if count != 0:
remove_unused_nodes(model)
return count
30 changes: 28 additions & 2 deletions onnxscript/rewriter/ort_fusions/cos_sin_cache_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,18 @@
from parameterized import parameterized

import onnxscript.optimizer
from onnxscript.rewriter.ort_fusions._rotary_embedding_models import test_case_1, test_case_2
from onnxscript.rewriter.ort_fusions._rotary_embedding_models import (
partial_rotary_test_case,
test_case_1,
test_case_2,
)
from onnxscript.rewriter.ort_fusions._smollm_1 import smollm_test_1
from onnxscript.rewriter.ort_fusions._test_utils import assert_allclose, ort_run
from onnxscript.rewriter.ort_fusions.cos_sin_cache import fuse_cos_sin_cache
from onnxscript.rewriter.ort_fusions.rotary_embedding import fuse_rotary_embedding
from onnxscript.rewriter.ort_fusions.rotary_embedding import (
fuse_partial_rotary_embedding,
fuse_rotary_embedding,
)


class TestCosSinCacheTransform(unittest.TestCase):
Expand All @@ -29,6 +36,10 @@ class TestCosSinCacheTransform(unittest.TestCase):
"test_case_2",
test_case_2,
),
(
"partial_rotary_test_case",
partial_rotary_test_case,
),
]
)
def test_cos_sin_fusion(self, name, test_data_constructor):
Expand All @@ -44,6 +55,21 @@ def test_cos_sin_fusion(self, name, test_data_constructor):
new_outputs = ort_run("optimized", model, inputs)
assert_allclose(new_outputs, original_outputs)

def test_partial_rotary_fusion(self):
test = partial_rotary_test_case()
model = test.get_onnx_model()
onnxscript.optimizer.optimize(model)
inputs = test.get_ort_inputs()
original_outputs = ort_run("original", model, inputs)
count = fuse_rotary_embedding(model)
self.assertGreater(count, 0)
count = fuse_cos_sin_cache(model)
self.assertGreater(count, 0)
count = fuse_partial_rotary_embedding(model)
self.assertGreater(count, 0)
new_outputs = ort_run("optimized", model, inputs)
assert_allclose(new_outputs, original_outputs)


if __name__ == "__main__":
unittest.main()
58 changes: 58 additions & 0 deletions onnxscript/rewriter/ort_fusions/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,69 @@ def rewrite(self, op, x, cos, sin, **_):
)


class PartialRotaryEmbeddingFusion(pattern.RewriteRuleClassBase):
def pattern(self, op, x, end1, start2):
x_part_1 = op.Slice(x, [0], end1, [3], [1])
x_part_2 = op.Slice(x, start2, [9223372036854775807], [3], [1])
x_part_1_rope = op.RotaryEmbedding(
x_part_1,
_allow_other_inputs=True,
_allow_other_attributes=True,
_domain="com.microsoft",
_outputs=["x_part_1_rope"],
)
return op.Concat(x_part_1_rope, x_part_2, axis=-1)

def check(self, op, x, end1, start2, x_part_1_rope, **_):
end1_value = _ir_utils.get_singleton_value(end1)
start2_value = _ir_utils.get_singleton_value(start2)
if not isinstance(end1_value, int) or not isinstance(start2_value, int):
return False
if end1_value != start2_value:
return False
rotary_embedding_attributes = x_part_1_rope.producer().attributes
if "rotary_embedding_dim" in rotary_embedding_attributes:
return False
if (
"interleaved" in rotary_embedding_attributes
and rotary_embedding_attributes["interleaved"].value != 0
):
return False
return True

def rewrite(self, op, x, end1, x_part_1_rope, **_):
# Create a modified version of the RotaryEmbedding op:
rotary_embedding_dim = _ir_utils.get_singleton_value(end1)
original_node = x_part_1_rope.producer()
inputs = list(original_node.inputs)
inputs[0] = x
attrs = dict(original_node.attributes)
attrs["rotary_embedding_dim"] = rotary_embedding_dim
return op.RotaryEmbedding(
*inputs,
**attrs,
_domain="com.microsoft",
)


_rule = RotaryEmbeddingFusion.rule()

_partial_embedding_rule = PartialRotaryEmbeddingFusion.rule()

rotary_embedding_rules = pattern.RewriteRuleSet([_rule])

partial_embedding_rules = pattern.RewriteRuleSet([_partial_embedding_rule])


def fuse_rotary_embedding(model: ir.Model) -> int:
count = rotary_embedding_rules.apply_to_model(model)
return count


def fuse_partial_rotary_embedding(model: ir.Model, debug: bool = False) -> int:
count = partial_embedding_rules.apply_to_model(model)
if count == 0 and debug:
tracer = pattern.MatchingTracer()
partial_embedding_rules.apply_to_model(model, tracer=tracer)
tracer.report()
return count
Loading