Skip to content

Commit 882a442

Browse files
authored
Fusion for partial rotary embedding (#2095)
Add a fusion rule for recognizing partial rotary embedding, along with a test case.
1 parent 32b54be commit 882a442

File tree

5 files changed

+169
-9
lines changed

5 files changed

+169
-9
lines changed

onnxscript/rewriter/_ir_utils.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,11 +79,15 @@ def get_numpy_value(val: ir.Value | None) -> np.ndarray | None:
7979
return None
8080

8181

82-
def get_singleton_value(val: ir.Value | None):
83-
"""Returns element of a single element tensor constant value, and None otherwise."""
82+
def get_singleton_value(val: ir.Value | None, rank: int | None = None):
    """Return the only element of a single-element constant tensor, else None.

    Args:
        val: The value to inspect; its constant contents are obtained via
            ``get_numpy_value``.
        rank: If given, the element is returned only when the constant's
            number of dimensions equals this rank.

    Returns:
        The element as a Python scalar (``ndarray.item()``), or None when no
        constant array is available, the array does not hold exactly one
        element, or the rank requirement is not met.
    """
    array = get_numpy_value(val)
    if array is None or array.size != 1:
        return None
    if rank is not None and array.ndim != rank:
        return None
    return array.item()
8892

8993

onnxscript/rewriter/ort_fusions/_rotary_embedding_models.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
from onnxscript.onnx_opset import opset18 as op
1111
from onnxscript.onnx_types import FLOAT, INT64
1212

13+
# A simple rotary embedding example
14+
1315

1416
# x: [B, H, S, E]
1517
# position_ids: [B, S]
@@ -57,6 +59,7 @@ def test_case_1():
5759
return _TestCase1()
5860

5961

62+
# A simple rotary embedding example with 1D position_ids
6063
# x: [B, H, S, E]
6164
# position_ids: [S]
6265
@script()
@@ -101,3 +104,67 @@ def get_ort_inputs(self):
101104

102105
def test_case_2():
103106
return _TestCase2()
107+
108+
109+
# A partial rotary embedding example:
110+
111+
rotary_embedding_dim = 32  # Abbreviated as "rd" in shape descriptors below
half_rotary_embedding_dim = rotary_embedding_dim // 2
# A pseudo-random inverse frequency tensor for the sake of this example.
# A fixed seed keeps the generated model identical from run to run, so that a
# failure observed once can be reproduced.
inv_freqs_value = numpy.random.default_rng(0).random(
    (1, half_rotary_embedding_dim, 1), dtype=numpy.float32
)
115+
116+
117+
# ONNX Script model of a partial rotary embedding: the first 32 entries of
# `query` along axis 3 are rotated, the remaining entries are passed through
# unchanged, and the two parts are concatenated again.
# NOTE: the statements below are translated by @script(), so they are kept in
# their original form; only formatting and comments differ.
@script()
def _partial_rotary_script(position_ids, query):
    # --- cos/sin tables derived from the position ids ---
    inv_freqs = op.Constant(value=inv_freqs_value)  # [1, rd/2, 1]
    position_ids_3d = op.Unsqueeze(position_ids, 1)  # [B, 1, S]
    # to=1 is the ONNX TensorProto code for FLOAT.
    position_ids_3d_float = op.Cast(position_ids_3d, to=1)
    matmul = op.MatMul(inv_freqs, position_ids_3d_float)  # [B, rd/2, S]
    transpose = op.Transpose(matmul, perm=[0, 2, 1])  # [B, S, rd/2]
    cat = op.Concat(transpose, transpose, axis=-1)  # [B, S, rd]
    cos_3d = op.Cos(cat)  # [B, S, rd]
    sin_3d = op.Sin(cat)  # [B, S, rd]

    # --- split the query for partial embedding ---
    # Slice arguments are (data, starts, ends, axes, steps); the end value
    # 9223372036854775807 is the int64 maximum, i.e. "up to the end of the axis".
    to_embed = op.Slice(query, [0], [32], [3], [1])
    unembedded = op.Slice(query, [32], [9223372036854775807], [3], [1])
    cos_4d = op.Unsqueeze(cos_3d, 1)  # [B, 1, S, rd]
    sin_4d = op.Unsqueeze(sin_3d, 1)  # [B, 1, S, rd]

    # --- rotate the leading part ---
    # Compute rotation of X as X * cos + rotate_half(X) * sin, where
    # rotate_half(X) essentially represents X rotated by 90 degrees.
    to_embed_times_cos = op.Mul(to_embed, cos_4d)
    to_embed_x = op.Slice(to_embed, [0], [16], [3], [1])
    to_embed_y = op.Slice(to_embed, [16], [9223372036854775807], [3], [1])
    minus_to_embed_y = op.Neg(to_embed_y)
    to_embed_rotated_90 = op.Concat(minus_to_embed_y, to_embed_x, axis=-1)
    to_embed_rotated_90_times_sin = op.Mul(to_embed_rotated_90, sin_4d)
    embedded = op.Add(to_embed_times_cos, to_embed_rotated_90_times_sin)

    # --- re-attach the part that was not rotated ---
    final = op.Concat(embedded, unembedded, axis=-1)
    return final
143+
144+
145+
class _PartialRotaryTestCase:
    """Test fixture bundling the partial-rotary model with matching ORT inputs.

    Both the model and the inputs are created on first request and cached on
    the instance, so repeated calls return the same objects.
    """

    def get_onnx_model(self):
        """Return the IR model built from ``_partial_rotary_script`` (cached)."""
        if not hasattr(self, "_onnx_model"):
            model_proto = _partial_rotary_script.to_model_proto(
                input_types=(
                    INT64["Batchsize", "Sequence"],
                    FLOAT["Batchsize", 32, "Sequence", 80],
                ),
                output_types=(FLOAT["Batchsize", 32, "Sequence", 80],),
            )
            self._onnx_model = ir.serde.deserialize_model(model_proto)
        return self._onnx_model

    def get_ort_inputs(self):
        """Return the input feed ``{"query", "position_ids"}`` (cached).

        ``query`` is a float32 array of shape (1, 32, 8, 80) drawn from a
        fixed-seed generator, so every run (and every instance) sees the same
        data and a numerical mismatch can be reproduced. ``position_ids`` is
        the int64 sequence 0..7 with shape (1, 8).
        """
        if not hasattr(self, "_ort_inputs"):
            rng = numpy.random.default_rng(1)
            self._ort_inputs = {
                "query": rng.random((1, 32, 8, 80), dtype=numpy.float32),
                "position_ids": numpy.arange(8, dtype=numpy.int64).reshape(1, 8),
            }
        return self._ort_inputs


def partial_rotary_test_case():
    """Create a fresh partial rotary embedding test fixture."""
    return _PartialRotaryTestCase()

onnxscript/rewriter/ort_fusions/cos_sin_cache.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -152,16 +152,21 @@ def rewrite(
152152
_cast_const_freqs = CosSinCacheFusion.rule(
153153
"CosSinCache_cast_const_freqs", 2048, cast=True, const_freqs=True
154154
)
155-
_cast = CosSinCacheFusion.rule(
156-
"CosSinCache_cast_no_const_freqs", 2048, cast=True, const_freqs=False
155+
_cast = CosSinCacheFusion.rule("CosSinCache_cast", 2048, cast=True, const_freqs=False)
156+
_const_freqs = CosSinCacheFusion.rule(
157+
"CosSinCache_const_freqs", 2048, cast=False, const_freqs=True
157158
)
158159
_basic = CosSinCacheFusion.rule("CosSinCache", 2048, cast=False)
159160

160-
cos_sin_cache_rules = pattern.RewriteRuleSet([_cast, _cast_const_freqs, _basic])
161+
cos_sin_cache_rules = pattern.RewriteRuleSet([_cast, _cast_const_freqs, _const_freqs, _basic])
161162

162163

163-
def fuse_cos_sin_cache(model: ir.Model) -> int:
164+
def fuse_cos_sin_cache(model: ir.Model, debug: bool = False) -> int:
    """Apply the cos/sin cache rewrite rules to ``model``.

    Args:
        model: The model to rewrite in place.
        debug: When true and no rule applied, the rules are run a second time
            with a ``MatchingTracer`` and its report is emitted.

    Returns:
        The count returned by ``cos_sin_cache_rules.apply_to_model``.
    """
    count = cos_sin_cache_rules.apply_to_model(model)
    if count:
        # Something was rewritten: drop the nodes that are no longer used.
        remove_unused_nodes(model)
    elif debug:
        # Nothing matched: re-run with tracing to show the match attempts.
        tracer = pattern.MatchingTracer()
        cos_sin_cache_rules.apply_to_model(model, tracer=tracer)
        tracer.report()
    return count

onnxscript/rewriter/ort_fusions/cos_sin_cache_test.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,18 @@
77
from parameterized import parameterized
88

99
import onnxscript.optimizer
10-
from onnxscript.rewriter.ort_fusions._rotary_embedding_models import test_case_1, test_case_2
10+
from onnxscript.rewriter.ort_fusions._rotary_embedding_models import (
11+
partial_rotary_test_case,
12+
test_case_1,
13+
test_case_2,
14+
)
1115
from onnxscript.rewriter.ort_fusions._smollm_1 import smollm_test_1
1216
from onnxscript.rewriter.ort_fusions._test_utils import assert_allclose, ort_run
1317
from onnxscript.rewriter.ort_fusions.cos_sin_cache import fuse_cos_sin_cache
14-
from onnxscript.rewriter.ort_fusions.rotary_embedding import fuse_rotary_embedding
18+
from onnxscript.rewriter.ort_fusions.rotary_embedding import (
19+
fuse_partial_rotary_embedding,
20+
fuse_rotary_embedding,
21+
)
1522

1623

1724
class TestCosSinCacheTransform(unittest.TestCase):
@@ -29,6 +36,10 @@ class TestCosSinCacheTransform(unittest.TestCase):
2936
"test_case_2",
3037
test_case_2,
3138
),
39+
(
40+
"partial_rotary_test_case",
41+
partial_rotary_test_case,
42+
),
3243
]
3344
)
3445
def test_cos_sin_fusion(self, name, test_data_constructor):
@@ -44,6 +55,21 @@ def test_cos_sin_fusion(self, name, test_data_constructor):
4455
new_outputs = ort_run("optimized", model, inputs)
4556
assert_allclose(new_outputs, original_outputs)
4657

58+
def test_partial_rotary_fusion(self):
    """Each fusion pass must fire on the partial rotary model, and the fused
    model must reproduce the outputs of the optimized original."""
    case = partial_rotary_test_case()
    model = case.get_onnx_model()
    onnxscript.optimizer.optimize(model)
    inputs = case.get_ort_inputs()
    original_outputs = ort_run("original", model, inputs)

    # The passes are applied in this order; each one has to rewrite something.
    fusion_passes = (
        fuse_rotary_embedding,
        fuse_cos_sin_cache,
        fuse_partial_rotary_embedding,
    )
    for fuse in fusion_passes:
        self.assertGreater(fuse(model), 0)

    new_outputs = ort_run("optimized", model, inputs)
    assert_allclose(new_outputs, original_outputs)
72+
4773

4874
if __name__ == "__main__":
4975
unittest.main()

onnxscript/rewriter/ort_fusions/rotary_embedding.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,69 @@ def rewrite(self, op, x, cos, sin, **_):
5353
)
5454

5555

56+
class PartialRotaryEmbeddingFusion(pattern.RewriteRuleClassBase):
    """Folds a slice / RotaryEmbedding / concat sequence into one RotaryEmbedding.

    Matches ``Concat(RotaryEmbedding(x[..., 0:end1]), x[..., start2:], axis=-1)``
    where both slices are taken along axis 3, and replaces it by a single
    ``com.microsoft`` RotaryEmbedding applied to ``x`` with the
    ``rotary_embedding_dim`` attribute set to ``end1``.
    """

    def pattern(self, op, x, end1, start2):
        # Part of x that goes through the RotaryEmbedding op.
        rotated_input = op.Slice(x, [0], end1, [3], [1])
        # Rest of x; 9223372036854775807 is the int64 maximum used as the slice end.
        passthrough = op.Slice(x, start2, [9223372036854775807], [3], [1])
        rotated = op.RotaryEmbedding(
            rotated_input,
            _allow_other_inputs=True,
            _allow_other_attributes=True,
            _domain="com.microsoft",
            _outputs=["x_part_1_rope"],
        )
        return op.Concat(rotated, passthrough, axis=-1)

    def check(self, op, x, end1, start2, x_part_1_rope, **_):
        """Accept the match only if the rewrite below is applicable.

        Requires that both slice bounds are constant integers and equal, that
        the matched RotaryEmbedding has no ``rotary_embedding_dim`` attribute
        yet, and that its ``interleaved`` attribute is absent or 0.
        """
        split_end = _ir_utils.get_singleton_value(end1)
        split_start = _ir_utils.get_singleton_value(start2)
        if not (isinstance(split_end, int) and isinstance(split_start, int)):
            return False
        # The second slice must start exactly where the first one ends.
        if split_end != split_start:
            return False
        rope_attributes = x_part_1_rope.producer().attributes
        if "rotary_embedding_dim" in rope_attributes:
            return False
        return (
            "interleaved" not in rope_attributes
            or rope_attributes["interleaved"].value == 0
        )

    def rewrite(self, op, x, end1, x_part_1_rope, **_):
        """Emit a copy of the matched RotaryEmbedding that takes the unsliced ``x``.

        All other inputs and attributes of the matched node are carried over;
        ``rotary_embedding_dim`` is added with the value of ``end1``.
        """
        rotary_dim = _ir_utils.get_singleton_value(end1)
        rope_node = x_part_1_rope.producer()
        new_inputs = list(rope_node.inputs)
        new_inputs[0] = x  # feed the whole tensor instead of its leading slice
        new_attributes = {**rope_node.attributes, "rotary_embedding_dim": rotary_dim}
        return op.RotaryEmbedding(
            *new_inputs,
            **new_attributes,
            _domain="com.microsoft",
        )
99+
100+
56101
_rule = RotaryEmbeddingFusion.rule()
57102

103+
_partial_embedding_rule = PartialRotaryEmbeddingFusion.rule()
104+
58105
rotary_embedding_rules = pattern.RewriteRuleSet([_rule])
59106

107+
partial_embedding_rules = pattern.RewriteRuleSet([_partial_embedding_rule])
108+
60109

61110
def fuse_rotary_embedding(model: ir.Model) -> int:
    """Apply ``rotary_embedding_rules`` to ``model`` and return the resulting count."""
    return rotary_embedding_rules.apply_to_model(model)
113+
114+
115+
def fuse_partial_rotary_embedding(model: ir.Model, debug: bool = False) -> int:
    """Apply the partial rotary embedding rewrite rules to ``model``.

    Args:
        model: The model to rewrite in place.
        debug: When true and no rule applied, the rules are run a second time
            with a ``MatchingTracer`` and its report is emitted.

    Returns:
        The count returned by ``partial_embedding_rules.apply_to_model``.
    """
    count = partial_embedding_rules.apply_to_model(model)
    if count != 0 or not debug:
        return count
    # Nothing matched and debugging was requested: trace a second attempt.
    tracer = pattern.MatchingTracer()
    partial_embedding_rules.apply_to_model(model, tracer=tracer)
    tracer.report()
    return count

0 commit comments

Comments
 (0)