Skip to content

Commit debc34d

Browse files
authored
[DRAFT] Extensions to transformer fusions (#2082)
* Extends the cos-sin-cache fusion to support 1D position-ids (without a batch dimension)
* Makes MatchingTracer a parameter of the rewriter to give users better control over how to report stats (for successful or failing matches)
* Improves the tracer output
1 parent 8ad2403 commit debc34d

16 files changed

+251
-81
lines changed

.lintrunner.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ exclude_patterns = [
5151
'onnxscript/rewriter/onnxruntime/transformers/fastgelu.py', # FIXME
5252
'onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py', # FIXME
5353
'onnxscript/rewriter/ort_fusions/_smollm_*.py', # onnxscript code
54+
'onnxscript/rewriter/ort_fusions/_rotary_embedding_models.py', # onnxscript code
5455
'onnxscript/_legacy_ir/irbuilder.py', # FIXME
5556
'onnxscript/rewriter/onnxruntime/transformers/multihead_attention.py', # FIXME
5657
'onnxscript/tools/function_unittest_producer.py', # FIXME

onnxscript/rewriter/_ir_utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,21 @@ def is_singleton_value(
103103
return math.isclose(scalar, expected, rel_tol=rtol)
104104

105105

106+
def is_1d_value(val: ir.Value | None, expected: list[int]) -> bool:
    """Return True if ``val`` is a known 1-D INT64 tensor whose elements equal ``expected``.

    Args:
        val: The IR value to inspect. ``None`` never matches.
        expected: The element values the tensor must hold, in order.

    Returns:
        True only when ``val`` has type information with dtype INT64,
        ``get_numpy_value(val)`` yields an array of rank 1, and that array's
        elements equal ``expected``. False in every other case.
    """
    if val is None:
        return False
    if not isinstance(val.type, ir.TypeProtocol):
        return False
    np_val = get_numpy_value(val)
    if np_val is None:
        return False
    if val.type.dtype != ir.DataType.INT64:
        return False
    # Check the rank explicitly. Relying on the ``tolist()`` comparison alone
    # lets an empty multi-dimensional array (e.g. shape (0, 3), whose
    # ``tolist()`` is ``[]``) match ``expected == []``.
    if np_val.ndim != 1 or np_val.size != len(expected):
        return False
    return np_val.tolist() == expected
119+
120+
106121
def has_rank(value: ir.Value | None, rank: int) -> bool:
107122
"""Returns True if the value is statically known to have the given rank, and False otherwise."""
108123
if value is None:
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
"""Small test case models for rotary embedding."""
5+
6+
import numpy
7+
8+
import onnxscript.ir as ir
9+
from onnxscript import script
10+
from onnxscript.onnx_opset import opset18 as op
11+
from onnxscript.onnx_types import FLOAT, INT64
12+
13+
14+
# x: [B, H, S, E]
15+
# position_ids: [B, S]
16+
# Test model: rotary embedding applied to x using cos/sin computed from 2-D position_ids.
@script()
17+
def _test_case_1_script(x: FLOAT[1, 4, 8, 8], position_ids: INT64[1, 8]) -> FLOAT[1, 4, 8, 8]:
18+
inv_freq = op.Constant(value_floats=[1.0, 2.0, 3.0, 4.0]) # 1-D constant with 4 entries
19+
inv_freq_3d = op.Unsqueeze(inv_freq, [0, 2]) # [4] => [1, 4, 1]
20+
position_ids_expanded = op.Unsqueeze(position_ids, [1]) # [B, S] => [B, 1, S]
21+
position_ids_float = op.Cast(position_ids_expanded, to=ir.DataType.FLOAT)
22+
freqs = op.MatMul(inv_freq_3d, position_ids_float) # [1, 4, 1] x [B, 1, S] => [B, E, S], with E = 4 = len(inv_freq) here
23+
freqs = op.Transpose(freqs, perm=[0, 2, 1]) # [B, S, E]
24+
emb = op.Concat(freqs, freqs, axis=-1) # duplicate along the last axis: [B, S, 2 * 4] = [B, S, 8]
25+
cos = op.Cos(emb)
26+
sin = op.Sin(emb)
27+
cos_4d = op.Unsqueeze(cos, 1) # [B, S, 8] => [B, 1, S, 8], broadcast against axis 1 of x
28+
sin_4d = op.Unsqueeze(sin, 1) # [B, S, 8] => [B, 1, S, 8]
29+
30+
x1 = op.Slice(x, [0], [4], [3], [1]) # starts=[0], ends=[4], axes=[3], steps=[1]: first half of the last axis
31+
x2 = op.Slice(x, [4], [8], [3], [1]) # second half of the last axis
32+
minus_x2 = op.Neg(x2)
33+
rotated_x = op.Concat(minus_x2, x1, axis=-1) # (-x2, x1): the "rotate half" of x
34+
rotary_embedding = op.Add(x * cos_4d, rotated_x * sin_4d) # x * cos + rotate_half(x) * sin
35+
return rotary_embedding
36+
37+
38+
class _TestCase1:
    """Lazily built model and inputs for the rotary-embedding case with 2-D ``position_ids``."""

    def get_onnx_model(self):
        """Return the IR model, deserializing it from the script on first use."""
        if hasattr(self, "_onnx_model"):
            return self._onnx_model
        self._onnx_model = ir.serde.deserialize_model(_test_case_1_script.to_model_proto())
        return self._onnx_model

    def get_ort_inputs(self):
        """Return the input dict; it is created once and reused on later calls."""
        if hasattr(self, "_ort_inputs"):
            return self._ort_inputs
        x = numpy.random.rand(1, 4, 8, 8).astype(numpy.float32)
        position_ids = numpy.arange(8, dtype=numpy.int64).reshape(1, 8)
        self._ort_inputs = {"x": x, "position_ids": position_ids}
        return self._ort_inputs
54+
55+
56+
def test_case_1():
    """Return a new ``_TestCase1`` test-data holder."""
    return _TestCase1()


# This is a factory, not a test. Its ``test_`` prefix is kept for existing
# importers, so opt out of name-based test collection (e.g. by pytest) in the
# test modules that import it.
test_case_1.__test__ = False
58+
59+
60+
# x: [B, H, S, E]
61+
# position_ids: [S]
62+
# Test model: same computation as the 2-D case, but position_ids is 1-D (no batch dimension).
@script()
63+
def _test_case_2_script(x: FLOAT[1, 4, 8, 8], position_ids: INT64[8]) -> FLOAT[1, 4, 8, 8]:
64+
inv_freq = op.Constant(value_floats=[1.0, 2.0, 3.0, 4.0]) # 1-D constant with 4 entries
65+
inv_freq_3d = op.Unsqueeze(inv_freq, [0, 2]) # [4] => [1, 4, 1]
66+
position_ids_expanded = op.Unsqueeze(position_ids, [0, 1]) # [S] => [1, 1, S]
67+
position_ids_float = op.Cast(position_ids_expanded, to=ir.DataType.FLOAT)
68+
freqs = op.MatMul(inv_freq_3d, position_ids_float) # [1, 4, 1] x [1, 1, S] => [B, E, S], with B = 1 and E = 4 = len(inv_freq) here
69+
freqs = op.Transpose(freqs, perm=[0, 2, 1]) # [B, S, E]
70+
emb = op.Concat(freqs, freqs, axis=-1) # duplicate along the last axis: [1, S, 2 * 4] = [1, S, 8]
71+
cos = op.Cos(emb)
72+
sin = op.Sin(emb)
73+
cos_4d = op.Unsqueeze(cos, 1) # [1, S, 8] => [1, 1, S, 8], broadcast against axis 1 of x
74+
sin_4d = op.Unsqueeze(sin, 1) # [1, S, 8] => [1, 1, S, 8]
75+
76+
x1 = op.Slice(x, [0], [4], [3], [1]) # starts=[0], ends=[4], axes=[3], steps=[1]: first half of the last axis
77+
x2 = op.Slice(x, [4], [8], [3], [1]) # second half of the last axis
78+
minus_x2 = op.Neg(x2)
79+
rotated_x = op.Concat(minus_x2, x1, axis=-1) # (-x2, x1): the "rotate half" of x
80+
rotary_embedding = op.Add(x * cos_4d, rotated_x * sin_4d) # x * cos + rotate_half(x) * sin
81+
return rotary_embedding
82+
83+
84+
class _TestCase2:
    """Lazily built model and inputs for the rotary-embedding case with 1-D ``position_ids``."""

    def get_onnx_model(self):
        """Return the IR model, deserializing it from the script on first use."""
        if hasattr(self, "_onnx_model"):
            return self._onnx_model
        self._onnx_model = ir.serde.deserialize_model(_test_case_2_script.to_model_proto())
        return self._onnx_model

    def get_ort_inputs(self):
        """Return the input dict; it is created once and reused on later calls."""
        if hasattr(self, "_ort_inputs"):
            return self._ort_inputs
        x = numpy.random.rand(1, 4, 8, 8).astype(numpy.float32)
        position_ids = numpy.arange(8, dtype=numpy.int64).reshape(8)
        self._ort_inputs = {"x": x, "position_ids": position_ids}
        return self._ort_inputs
100+
101+
102+
def test_case_2():
    """Return a new ``_TestCase2`` test-data holder."""
    return _TestCase2()


# This is a factory, not a test. Its ``test_`` prefix is kept for existing
# importers, so opt out of name-based test collection (e.g. by pytest) in the
# test modules that import it.
test_case_2.__test__ = False

onnxscript/rewriter/ort_fusions/_smollm_1.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ def make_model_with_random_weights():
234234
return model
235235

236236

237-
class TestData:
237+
class _SmollmTest1:
238238
def get_onnx_model(self):
239239
if not hasattr(self, "_onnx_model"):
240240
model_proto = make_model_with_random_weights()
@@ -251,3 +251,7 @@ def get_ort_inputs(self):
251251
}
252252
self._ort_inputs = inputs
253253
return self._ort_inputs
254+
255+
256+
def smollm_test_1():
    """Return a new ``_SmollmTest1`` test-data holder."""
    test_data = _SmollmTest1()
    return test_data

onnxscript/rewriter/ort_fusions/_smollm_2.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ def make_model_with_random_weights():
447447
return model
448448

449449

450-
class TestData:
450+
class _SmollmTest2:
451451
def get_onnx_model(self):
452452
if not hasattr(self, "_onnx_model"):
453453
model_proto = make_model_with_random_weights()
@@ -465,3 +465,7 @@ def get_ort_inputs(self):
465465
}
466466
self._ort_inputs = inputs
467467
return self._ort_inputs
468+
469+
470+
def smollm_test_2():
    """Return a new ``_SmollmTest2`` test-data holder."""
    test_data = _SmollmTest2()
    return test_data

onnxscript/rewriter/ort_fusions/cos_sin_cache.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -58,14 +58,18 @@ def __init__(
5858
def cleanup(self):
5959
self._inv_freq_cos_sin_cache.clear()
6060

61-
def pattern(self, op, x, inv_freq, position_ids, interleaved, num_heads, freqs, dtype):
61+
def pattern(
62+
self, op, x, inv_freq, position_ids, interleaved, num_heads, freqs, dtype, extra_dims
63+
):
6264
if not self._const_freqs:
6365
# Compute freqs from inv_freq and position_ids. In the _const_freqs case,
6466
# this computation has been constant-folded away and freqs is a constant.
6567
# B: batch size, S: sequence length, E: embedding dimension
66-
# position_ids: [B, S]
68+
# position_ids: [B, S] or [S]
6769
# inv_freq: [1, E, 1]
68-
position_ids_expanded = op.Unsqueeze(position_ids, 1) # [B, S] => [B, 1, S]
70+
position_ids_expanded = op.Unsqueeze(
71+
position_ids, extra_dims
72+
) # [B, S] | [S] => [B, 1, S]
6973
position_ids_expanded = op.Cast(position_ids_expanded, to=ir.DataType.FLOAT)
7074
# if self._reshape:
7175
# position_ids_expanded = op.Expand(position_ids_expanded, _allow_other_inputs=True)
@@ -92,11 +96,17 @@ def pattern(self, op, x, inv_freq, position_ids, interleaved, num_heads, freqs,
9296
_domain="ai.onnxruntime.fusion",
9397
)
9498

95-
def check(self, context, inv_freq, position_ids, freqs, **_):
99+
def check(self, context, inv_freq, position_ids, freqs, extra_dims, **_):
96100
# TODO(rama): handle redundant reshape/expand
97101
if self._const_freqs:
98102
return (freqs.const_value is not None) and _ir_utils.has_rank(freqs, 3)
99-
if not _ir_utils.has_rank(position_ids, 2):
103+
if (
104+
_ir_utils.has_rank(position_ids, 2) and _ir_utils.is_singleton_value(extra_dims, 1)
105+
) or (
106+
_ir_utils.has_rank(position_ids, 1) and _ir_utils.is_1d_value(extra_dims, [0, 1])
107+
):
108+
pass
109+
else:
100110
return False
101111
if not _ir_utils.has_rank(inv_freq, 3):
102112
return False
@@ -125,6 +135,9 @@ def rewrite(
125135
cos_2d = op.Cast(cos_2d, to=dtype)
126136
sin_2d = op.Cast(sin_2d, to=dtype)
127137
self._inv_freq_cos_sin_cache[inv_freq] = (cos_2d, sin_2d)
138+
if _ir_utils.has_rank(position_ids, 1):
139+
zero_1d = op.Constant(value_ints=[0])
140+
position_ids = op.Unsqueeze(position_ids, zero_1d)
128141
return op.RotaryEmbedding(
129142
x,
130143
position_ids,
@@ -146,14 +159,9 @@ def rewrite(
146159

147160
cos_sin_cache_rules = pattern.RewriteRuleSet([_cast, _cast_const_freqs, _basic])
148161

149-
debug: bool = True
150-
151162

152163
def fuse_cos_sin_cache(model: ir.Model) -> int:
153164
count = cos_sin_cache_rules.apply_to_model(model)
154-
if count == 0 and debug:
155-
cos_sin_cache_rules.apply_to_model(model, debug=True)
156-
else:
157-
print(f"CosSinCache count: {count}")
165+
if count != 0:
158166
remove_unused_nodes(model)
159167
return count

onnxscript/rewriter/ort_fusions/cos_sin_cache_test.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,38 @@
44

55
import unittest
66

7+
from parameterized import parameterized
8+
79
import onnxscript.optimizer
8-
from onnxscript.rewriter.ort_fusions._smollm_1 import TestData
10+
from onnxscript.rewriter.ort_fusions._rotary_embedding_models import test_case_1, test_case_2
11+
from onnxscript.rewriter.ort_fusions._smollm_1 import smollm_test_1
912
from onnxscript.rewriter.ort_fusions._test_utils import assert_allclose, ort_run
1013
from onnxscript.rewriter.ort_fusions.cos_sin_cache import fuse_cos_sin_cache
1114
from onnxscript.rewriter.ort_fusions.rotary_embedding import fuse_rotary_embedding
1215

1316

1417
class TestCosSinCacheTransform(unittest.TestCase):
15-
def test_smollm(self):
16-
smollm_test = TestData()
17-
model = smollm_test.get_onnx_model()
18+
@parameterized.expand(
19+
[
20+
(
21+
"smollm_test_1",
22+
smollm_test_1,
23+
),
24+
(
25+
"test_case_1",
26+
test_case_1,
27+
),
28+
(
29+
"test_case_2",
30+
test_case_2,
31+
),
32+
]
33+
)
34+
def test_cos_sin_fusion(self, name, test_data_constructor):
35+
test = test_data_constructor()
36+
model = test.get_onnx_model()
1837
onnxscript.optimizer.optimize(model)
19-
inputs = smollm_test.get_ort_inputs()
38+
inputs = test.get_ort_inputs()
2039
original_outputs = ort_run("original", model, inputs)
2140
count = fuse_rotary_embedding(model)
2241
self.assertGreater(count, 0)

onnxscript/rewriter/ort_fusions/mha.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -186,13 +186,7 @@ def rewrite(
186186

187187
mha_rules = pattern.RewriteRuleSet([_rule1])
188188

189-
debug: bool = True
190-
191189

192190
def fuse_mha(model: ir.Model) -> int:
193191
count = mha_rules.apply_to_model(model)
194-
if count == 0 and debug:
195-
mha_rules.apply_to_model(model, debug=True)
196-
else:
197-
print(f"MHA count: {count}")
198192
return count

onnxscript/rewriter/ort_fusions/mha_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@
66

77
import onnxscript.optimizer
88
import onnxscript.rewriter.ort_fusions._core as xformers
9-
from onnxscript.rewriter.ort_fusions._smollm_2 import TestData
9+
from onnxscript.rewriter.ort_fusions._smollm_2 import smollm_test_2
1010
from onnxscript.rewriter.ort_fusions._test_utils import assert_allclose, ort_run
1111

1212

1313
class TestMultiHeadAttention(unittest.TestCase):
1414
def test_smollm(self):
1515
# Generate model
16-
smollm_test = TestData()
16+
smollm_test = smollm_test_2()
1717
model = smollm_test.get_onnx_model()
1818
onnxscript.optimizer.optimize(model)
1919
xformers.fuse_rms_normalization(model)

onnxscript/rewriter/ort_fusions/rms_normalization_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,14 @@
55
import unittest
66

77
import onnxscript.optimizer
8-
from onnxscript.rewriter.ort_fusions._smollm_1 import TestData
8+
from onnxscript.rewriter.ort_fusions._smollm_1 import smollm_test_1
99
from onnxscript.rewriter.ort_fusions._test_utils import assert_allclose, ort_run
1010
from onnxscript.rewriter.ort_fusions.rms_normalization import fuse_rms_normalization
1111

1212

1313
class TestRmsNormalization(unittest.TestCase):
1414
def test_smollm(self):
15-
smollm_test = TestData()
15+
smollm_test = smollm_test_1()
1616
model = smollm_test.get_onnx_model()
1717
onnxscript.optimizer.optimize(model)
1818
inputs = smollm_test.get_ort_inputs()

onnxscript/rewriter/ort_fusions/rotary_embedding.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,7 @@ def rewrite(self, op, x, cos, sin, **_):
5757

5858
rotary_embedding_rules = pattern.RewriteRuleSet([_rule])
5959

60-
debug: bool = True
61-
6260

6361
def fuse_rotary_embedding(model: ir.Model) -> int:
6462
count = rotary_embedding_rules.apply_to_model(model)
65-
if count == 0 and debug:
66-
rotary_embedding_rules.apply_to_model(model, debug=True)
67-
else:
68-
print(f"Rotary Embedding count: {count}")
6963
return count

onnxscript/rewriter/ort_fusions/rotary_embedding_test.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,30 @@
44

55
import unittest
66

7+
from parameterized import parameterized
8+
79
import onnxscript.optimizer
8-
from onnxscript.rewriter.ort_fusions._smollm_1 import TestData
10+
from onnxscript.rewriter.ort_fusions._rotary_embedding_models import test_case_1
11+
from onnxscript.rewriter.ort_fusions._smollm_1 import smollm_test_1
912
from onnxscript.rewriter.ort_fusions.rotary_embedding import fuse_rotary_embedding
1013

1114

1215
class TestRotaryEmbedding(unittest.TestCase):
13-
def test_smollm(self):
14-
smollm_test = TestData()
15-
model = smollm_test.get_onnx_model()
16+
@parameterized.expand(
17+
[
18+
(
19+
"test_case_1",
20+
test_case_1,
21+
),
22+
(
23+
"smollm_test_1",
24+
smollm_test_1,
25+
),
26+
]
27+
)
28+
def test_rotary_embedding_fusion(self, name, test_data_constructor):
29+
test = test_data_constructor()
30+
model = test.get_onnx_model()
1631
onnxscript.optimizer.optimize(model)
1732
fuse_rotary_embedding(model)
1833
op_types = [n.op_type for n in model.graph]

onnxscript/rewriter/ort_fusions/sdpa.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -95,13 +95,7 @@ def rewrite(self, op, query, key_transposed, value, mask, **_):
9595
]
9696
)
9797

98-
debug: bool = True
99-
10098

10199
def fuse_sdpa(model: ir.Model) -> int:
102100
count = sdpa_rules.apply_to_model(model)
103-
if count == 0 and debug:
104-
sdpa_rules.apply_to_model(model, debug=True)
105-
else:
106-
print(f"SDPA count: {count}")
107101
return count

0 commit comments

Comments
 (0)