From 164815e8f3e4265a87d39ab4ba456b2cc6b987f2 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Mon, 10 Mar 2025 12:42:23 -0700 Subject: [PATCH 1/6] Add test case --- onnxscript/rewriter/_ir_utils.py | 10 ++- .../ort_fusions/_rotary_embedding_models.py | 73 +++++++++++++++++++ .../rewriter/ort_fusions/rotary_embedding.py | 49 +++++++++++++ 3 files changed, 129 insertions(+), 3 deletions(-) diff --git a/onnxscript/rewriter/_ir_utils.py b/onnxscript/rewriter/_ir_utils.py index c17443b9ba..51f7b911be 100644 --- a/onnxscript/rewriter/_ir_utils.py +++ b/onnxscript/rewriter/_ir_utils.py @@ -79,11 +79,15 @@ def get_numpy_value(val: ir.Value | None) -> np.ndarray | None: return None -def get_singleton_value(val: ir.Value | None): - """Returns element of a single element tensor constant value, and None otherwise.""" +def get_singleton_value(val: ir.Value | None, rank: int | None = None): + """Returns element of a single element tensor constant value, and None otherwise. + + If rank is specified, it checks that the value has the given rank. + """ np_val = get_numpy_value(val) if np_val is not None and np_val.size == 1: - return np_val.item() + if rank is None or (np_val.ndim == rank): + return np_val.item() return None diff --git a/onnxscript/rewriter/ort_fusions/_rotary_embedding_models.py b/onnxscript/rewriter/ort_fusions/_rotary_embedding_models.py index 9eb5a0b36e..c549377c79 100644 --- a/onnxscript/rewriter/ort_fusions/_rotary_embedding_models.py +++ b/onnxscript/rewriter/ort_fusions/_rotary_embedding_models.py @@ -10,6 +10,8 @@ from onnxscript.onnx_opset import opset18 as op from onnxscript.onnx_types import FLOAT, INT64 +# A simple rotary embedding example + # x: [B, H, S, E] # position_ids: [B, S] @@ -57,6 +59,7 @@ def test_case_1(): return _TestCase1() +# A simple rotary embedding example with 1D position_ids # x: [B, H, S, E] # position_ids: [S] @script() @@ -101,3 +104,73 @@ def get_ort_inputs(self): def test_case_2(): return _TestCase2() + + +# A partial rotary embedding example: + +# A random inverse frequency tensor for the sake of this example. +inv_freqs_value = numpy.random.rand(1, 40, 1).astype(numpy.float32) +# inv_freqs_value = make_tensor("value", 1, dims=[1, 40, 1], vals=[1.0]*40) + + +@script() +def _partial_rotary_script( + position_ids: INT64["Batchsize", "Sequence"], query: FLOAT["Batchsize", 32, "Sequence", 80] +) -> FLOAT["Batchsize", 32, "Sequence", 80]: + val_0 = op.Shape(position_ids, end=1, start=0) + sym_size_int_6 = op.Squeeze(val_0) + _to_copy_1 = op.Constant(value=inv_freqs_value) + val_25 = op.Reshape(sym_size_int_6, [-1], allowzero=0) + val_28 = op.Concat(val_25, [-1], [1], axis=0) + val_30 = op.Abs(val_28) + expand = op.Expand(_to_copy_1, val_30) + unsqueeze_2 = op.Unsqueeze(position_ids, 1) + _to_copy_2 = op.Cast(unsqueeze_2, to=1) + matmul = op.MatMul(expand, _to_copy_2) + transpose = op.Transpose(matmul, perm=[0, 2, 1]) + cat = op.Concat(transpose, transpose, axis=-1) + cos = op.Cos(cat) + sin = op.Sin(cat) + val_63 = op.Constant(value_ints=[1]) + slice_4 = op.Slice(query, [0], [32], [3], val_63) + val_73 = op.Constant(value_ints=[1]) + slice_5 = op.Slice(query, [32], [9223372036854775807], [3], val_73) + val_83 = op.Constant(value_ints=[1]) + slice_6 = op.Slice(cos, [0], [32], [2], val_83) + val_93 = op.Constant(value_ints=[1]) + slice_7 = op.Slice(sin, [0], [32], [2], val_93) + unsqueeze_3 = op.Unsqueeze(slice_6, 1) + unsqueeze_4 = op.Unsqueeze(slice_7, 1) + mul_55 = op.Mul(slice_4, unsqueeze_3) + val_106 = op.Constant(value_ints=[1]) + slice_8 = op.Slice(slice_4, [0], [16], [3], val_106) + val_116 = op.Constant(value_ints=[1]) + slice_9 = op.Slice(slice_4, [16], [9223372036854775807], [3], val_116) + neg = op.Neg(slice_9) + cat_1 = op.Concat(neg, slice_8, axis=-1) + mul_76 = op.Mul(cat_1, unsqueeze_4) + add_101 = op.Add(mul_55, mul_76) + cat_2 = op.Concat(add_101, slice_5, axis=-1) + return cat_2 + + +class _PartialRotaryTestCase: + def get_onnx_model(self): + if not hasattr(self, "_onnx_model"): + model_proto = _partial_rotary_script.to_model_proto() + model = ir.serde.deserialize_model(model_proto) + self._onnx_model = model + return self._onnx_model + + def get_ort_inputs(self): + if not hasattr(self, "_ort_inputs"): + inputs = { + "query": numpy.random.rand(1, 32, 8, 80).astype(numpy.float32), + "position_ids": numpy.arange(8, dtype=numpy.int64).reshape(1, 8), + } + self._ort_inputs = inputs + return self._ort_inputs + + +def partial_rotary_test_case(): + return _PartialRotaryTestCase() diff --git a/onnxscript/rewriter/ort_fusions/rotary_embedding.py b/onnxscript/rewriter/ort_fusions/rotary_embedding.py index d8ab31a428..c09033e824 100644 --- a/onnxscript/rewriter/ort_fusions/rotary_embedding.py +++ b/onnxscript/rewriter/ort_fusions/rotary_embedding.py @@ -53,10 +53,59 @@ def rewrite(self, op, x, cos, sin, **_): ) +class PartialRotaryEmbeddingFusion(pattern.RewriteRuleClassBase): + def pattern(self, op, x, end1, start2): + x_part_1 = op.Slice(x, [0], end1, [3], [1]) + x_part_2 = op.Slice(x, start2, [9223372036854775807], [3], [1]) + x_part_1_rope = op.RotaryEmbedding( + x_part_1, + _allow_other_inputs=True, + _allow_other_attributes=True, + _domain="ai.onnxruntime.fusion", + _outputs=["x_part_1_rope"], + ) + return op.Concat(x_part_1_rope, x_part_2, axis=-1) + + def check(self, op, x, end1, start2, x_part_1_rope, **_): + end1_value = _ir_utils.get_singleton_value(end1) + start2_value = _ir_utils.get_singleton_value(start2) + if not isinstance(end1_value, int) or not isinstance(start2_value, int): + return False + if end1_value != start2_value: + return False + rotary_embedding_attributes = x_part_1_rope.producer().attributes + if "rotary_embedding_dim" in rotary_embedding_attributes: + return False + if ( + "interleaved" in rotary_embedding_attributes + and rotary_embedding_attributes["interleaved"].value != 0 + ): + return False + return True + + def rewrite(self, op, x, end1, x_part_1_rope, **_): + # Create a modified version of the RotaryEmbedding op: + rotary_embedding_dim = _ir_utils.get_singleton_value(end1) + original_node = x_part_1_rope.producer() + inputs = list(original_node.inputs) + inputs[0] = x + attrs = dict(original_node.attributes) + attrs["rotary_embedding_dim"] = rotary_embedding_dim + return op.RotaryEmbedding( + *inputs, + **attrs, + _domain="ai.onnxruntime.fusion", + ) + + _rule = RotaryEmbeddingFusion.rule() +_partial_embedding_rule = PartialRotaryEmbeddingFusion.rule() + rotary_embedding_rules = pattern.RewriteRuleSet([_rule]) +partial_embedding_rules = pattern.RewriteRuleSet([_partial_embedding_rule]) + def fuse_rotary_embedding(model: ir.Model) -> int: count = rotary_embedding_rules.apply_to_model(model) From f22d29684437f4e680f78ef1918a985d78ab9220 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Mon, 10 Mar 2025 14:31:55 -0700 Subject: [PATCH 2/6] Partial rotary embedding fusion --- onnxscript/rewriter/llama_rule_sets.py | 1 + .../ort_fusions/_rotary_embedding_models.py | 26 ++++++++-------- .../rewriter/ort_fusions/cos_sin_cache.py | 15 +++++++--- .../ort_fusions/cos_sin_cache_test.py | 30 +++++++++++++++++-- .../rewriter/ort_fusions/rotary_embedding.py | 13 ++++++-- 5 files changed, 64 insertions(+), 21 deletions(-) diff --git a/onnxscript/rewriter/llama_rule_sets.py b/onnxscript/rewriter/llama_rule_sets.py index 2dd3fd8e3f..17df20267c 100644 --- a/onnxscript/rewriter/llama_rule_sets.py +++ b/onnxscript/rewriter/llama_rule_sets.py @@ -304,5 +304,6 @@ def llama_p0_rule_set() -> orp.RewriteRuleSet: transpose_identity_rule, transpose_transpose_rule, unsqueeze_unsqueeze_rule, + squeeze_reshape_1d_rule, ] ) diff --git a/onnxscript/rewriter/ort_fusions/_rotary_embedding_models.py b/onnxscript/rewriter/ort_fusions/_rotary_embedding_models.py index c549377c79..50843f874c 100644 --- a/onnxscript/rewriter/ort_fusions/_rotary_embedding_models.py +++ b/onnxscript/rewriter/ort_fusions/_rotary_embedding_models.py @@ -109,7 +109,7 @@ def test_case_2(): # A partial rotary embedding example: # A random inverse frequency tensor for the sake of this example. -inv_freqs_value = numpy.random.rand(1, 40, 1).astype(numpy.float32) +inv_freqs_value = numpy.random.rand(1, 16, 1).astype(numpy.float32) # inv_freqs_value = make_tensor("value", 1, dims=[1, 40, 1], vals=[1.0]*40) @@ -117,13 +117,13 @@ def test_case_2(): def _partial_rotary_script( position_ids: INT64["Batchsize", "Sequence"], query: FLOAT["Batchsize", 32, "Sequence", 80] ) -> FLOAT["Batchsize", 32, "Sequence", 80]: - val_0 = op.Shape(position_ids, end=1, start=0) - sym_size_int_6 = op.Squeeze(val_0) + # val_0 = op.Shape(position_ids, end=1, start=0) + # sym_size_int_6 = op.Squeeze(val_0) _to_copy_1 = op.Constant(value=inv_freqs_value) - val_25 = op.Reshape(sym_size_int_6, [-1], allowzero=0) - val_28 = op.Concat(val_25, [-1], [1], axis=0) - val_30 = op.Abs(val_28) - expand = op.Expand(_to_copy_1, val_30) + # val_25 = op.Reshape(sym_size_int_6, [-1], allowzero=0) + # val_28 = op.Concat(val_25, [-1], [1], axis=0) + # val_30 = op.Abs(val_28) + expand = _to_copy_1 # op.Expand(_to_copy_1, val_30) unsqueeze_2 = op.Unsqueeze(position_ids, 1) _to_copy_2 = op.Cast(unsqueeze_2, to=1) matmul = op.MatMul(expand, _to_copy_2) @@ -135,12 +135,12 @@ def _partial_rotary_script( slice_4 = op.Slice(query, [0], [32], [3], val_63) val_73 = op.Constant(value_ints=[1]) slice_5 = op.Slice(query, [32], [9223372036854775807], [3], val_73) - val_83 = op.Constant(value_ints=[1]) - slice_6 = op.Slice(cos, [0], [32], [2], val_83) - val_93 = op.Constant(value_ints=[1]) - slice_7 = op.Slice(sin, [0], [32], [2], val_93) - unsqueeze_3 = op.Unsqueeze(slice_6, 1) - unsqueeze_4 = op.Unsqueeze(slice_7, 1) + # val_83 = op.Constant(value_ints=[1]) + # slice_6 = op.Slice(cos, [0], [32], [2], val_83) + # val_93 = op.Constant(value_ints=[1]) + # slice_7 = op.Slice(sin, [0], [32], [2], val_93) + unsqueeze_3 = op.Unsqueeze(cos, 1) + unsqueeze_4 = op.Unsqueeze(sin, 1) mul_55 = op.Mul(slice_4, unsqueeze_3) val_106 = op.Constant(value_ints=[1]) slice_8 = op.Slice(slice_4, [0], [16], [3], val_106) diff --git a/onnxscript/rewriter/ort_fusions/cos_sin_cache.py b/onnxscript/rewriter/ort_fusions/cos_sin_cache.py index d1a391e9ae..3b2de698b1 100644 --- a/onnxscript/rewriter/ort_fusions/cos_sin_cache.py +++ b/onnxscript/rewriter/ort_fusions/cos_sin_cache.py @@ -152,16 +152,23 @@ def rewrite( _cast_const_freqs = CosSinCacheFusion.rule( "CosSinCache_cast_const_freqs", 2048, cast=True, const_freqs=True ) -_cast = CosSinCacheFusion.rule( - "CosSinCache_cast_no_const_freqs", 2048, cast=True, const_freqs=False +_cast = CosSinCacheFusion.rule("CosSinCache_cast", 2048, cast=True, const_freqs=False) +_const_freqs = CosSinCacheFusion.rule( + "CosSinCache_const_freqs", 2048, cast=False, const_freqs=True ) _basic = CosSinCacheFusion.rule("CosSinCache", 2048, cast=False) -cos_sin_cache_rules = pattern.RewriteRuleSet([_cast, _cast_const_freqs, _basic]) +# cos_sin_cache_rules = pattern.RewriteRuleSet([_cast, _cast_const_freqs, _const_freqs, _basic]) +cos_sin_cache_rules = pattern.RewriteRuleSet([_cast, _basic]) -def fuse_cos_sin_cache(model: ir.Model) -> int: + +def fuse_cos_sin_cache(model: ir.Model, debug: bool) -> int: count = cos_sin_cache_rules.apply_to_model(model) + if count == 0 and debug: + tracer = pattern.MatchingTracer() + cos_sin_cache_rules.apply_to_model(model, tracer=tracer) + tracer.report() if count != 0: remove_unused_nodes(model) return count diff --git a/onnxscript/rewriter/ort_fusions/cos_sin_cache_test.py b/onnxscript/rewriter/ort_fusions/cos_sin_cache_test.py index fcc735f2cc..67cb058fd3 100644 --- a/onnxscript/rewriter/ort_fusions/cos_sin_cache_test.py +++ b/onnxscript/rewriter/ort_fusions/cos_sin_cache_test.py @@ -7,11 +7,18 @@ from parameterized import parameterized import onnxscript.optimizer -from onnxscript.rewriter.ort_fusions._rotary_embedding_models import test_case_1, test_case_2 +from onnxscript.rewriter.ort_fusions._rotary_embedding_models import ( + partial_rotary_test_case, + test_case_1, + test_case_2, +) from onnxscript.rewriter.ort_fusions._smollm_1 import smollm_test_1 from onnxscript.rewriter.ort_fusions._test_utils import assert_allclose, ort_run from onnxscript.rewriter.ort_fusions.cos_sin_cache import fuse_cos_sin_cache -from onnxscript.rewriter.ort_fusions.rotary_embedding import fuse_rotary_embedding +from onnxscript.rewriter.ort_fusions.rotary_embedding import ( + fuse_partial_rotary_embedding, + fuse_rotary_embedding, +) class TestCosSinCacheTransform(unittest.TestCase): @@ -29,6 +36,10 @@ class TestCosSinCacheTransform(unittest.TestCase): "test_case_2", test_case_2, ), + ( + "partial_rotary_test_case", + partial_rotary_test_case, + ), ] ) def test_cos_sin_fusion(self, name, test_data_constructor): @@ -44,6 +55,21 @@ def test_cos_sin_fusion(self, name, test_data_constructor): new_outputs = ort_run("optimized", model, inputs) assert_allclose(new_outputs, original_outputs) + def test_partial_rotary_fusion(self): + test = partial_rotary_test_case() + model = test.get_onnx_model() + onnxscript.optimizer.optimize(model) + inputs = test.get_ort_inputs() + original_outputs = ort_run("original", model, inputs) + count = fuse_rotary_embedding(model) + self.assertGreater(count, 0) + count = fuse_cos_sin_cache(model) + self.assertGreater(count, 0) + count = fuse_partial_rotary_embedding(model) + self.assertGreater(count, 0) + new_outputs = ort_run("optimized", model, inputs) + assert_allclose(new_outputs, original_outputs) + if __name__ == "__main__": unittest.main() diff --git a/onnxscript/rewriter/ort_fusions/rotary_embedding.py b/onnxscript/rewriter/ort_fusions/rotary_embedding.py index c09033e824..c637fcc66f 100644 --- a/onnxscript/rewriter/ort_fusions/rotary_embedding.py +++ b/onnxscript/rewriter/ort_fusions/rotary_embedding.py @@ -61,7 +61,7 @@ def pattern(self, op, x, end1, start2): x_part_1, _allow_other_inputs=True, _allow_other_attributes=True, - _domain="ai.onnxruntime.fusion", + _domain="com.microsoft", _outputs=["x_part_1_rope"], ) return op.Concat(x_part_1_rope, x_part_2, axis=-1) @@ -94,7 +94,7 @@ def rewrite(self, op, x, end1, x_part_1_rope, **_): return op.RotaryEmbedding( *inputs, **attrs, - _domain="ai.onnxruntime.fusion", + _domain="com.microsoft", ) @@ -110,3 +110,12 @@ def rewrite(self, op, x, end1, x_part_1_rope, **_): def fuse_rotary_embedding(model: ir.Model) -> int: count = rotary_embedding_rules.apply_to_model(model) return count + + +def fuse_partial_rotary_embedding(model: ir.Model, debug: bool = False) -> int: + count = partial_embedding_rules.apply_to_model(model) + if count == 0 and debug: + tracer = pattern.MatchingTracer() + partial_embedding_rules.apply_to_model(model, tracer=tracer) + tracer.report() + return count From 653cd1a1c52007bc97cb2020078fed745224cbb0 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Mon, 10 Mar 2025 15:15:09 -0700 Subject: [PATCH 3/6] Minor fixes --- .../ort_fusions/_rotary_embedding_models.py | 14 ++------------ onnxscript/rewriter/ort_fusions/cos_sin_cache.py | 4 +--- 2 files changed, 3 insertions(+), 15 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/_rotary_embedding_models.py b/onnxscript/rewriter/ort_fusions/_rotary_embedding_models.py index 50843f874c..885d94931e 100644 --- a/onnxscript/rewriter/ort_fusions/_rotary_embedding_models.py +++ b/onnxscript/rewriter/ort_fusions/_rotary_embedding_models.py @@ -117,16 +117,10 @@ def test_case_2(): def _partial_rotary_script( position_ids: INT64["Batchsize", "Sequence"], query: FLOAT["Batchsize", 32, "Sequence", 80] ) -> FLOAT["Batchsize", 32, "Sequence", 80]: - # val_0 = op.Shape(position_ids, end=1, start=0) - # sym_size_int_6 = op.Squeeze(val_0) - _to_copy_1 = op.Constant(value=inv_freqs_value) - # val_25 = op.Reshape(sym_size_int_6, [-1], allowzero=0) - # val_28 = op.Concat(val_25, [-1], [1], axis=0) - # val_30 = op.Abs(val_28) - expand = _to_copy_1 # op.Expand(_to_copy_1, val_30) + inv_freqs = op.Constant(value=inv_freqs_value) unsqueeze_2 = op.Unsqueeze(position_ids, 1) _to_copy_2 = op.Cast(unsqueeze_2, to=1) - matmul = op.MatMul(expand, _to_copy_2) + matmul = op.MatMul(inv_freqs, _to_copy_2) transpose = op.Transpose(matmul, perm=[0, 2, 1]) cat = op.Concat(transpose, transpose, axis=-1) cos = op.Cos(cat) @@ -135,10 +129,6 @@ def _partial_rotary_script( slice_4 = op.Slice(query, [0], [32], [3], val_63) val_73 = op.Constant(value_ints=[1]) slice_5 = op.Slice(query, [32], [9223372036854775807], [3], val_73) - # val_83 = op.Constant(value_ints=[1]) - # slice_6 = op.Slice(cos, [0], [32], [2], val_83) - # val_93 = op.Constant(value_ints=[1]) - # slice_7 = op.Slice(sin, [0], [32], [2], val_93) unsqueeze_3 = op.Unsqueeze(cos, 1) unsqueeze_4 = op.Unsqueeze(sin, 1) mul_55 = op.Mul(slice_4, unsqueeze_3) diff --git a/onnxscript/rewriter/ort_fusions/cos_sin_cache.py b/onnxscript/rewriter/ort_fusions/cos_sin_cache.py index 3b2de698b1..001484152a 100644 --- a/onnxscript/rewriter/ort_fusions/cos_sin_cache.py +++ b/onnxscript/rewriter/ort_fusions/cos_sin_cache.py @@ -158,9 +158,7 @@ def rewrite( ) _basic = CosSinCacheFusion.rule("CosSinCache", 2048, cast=False) -# cos_sin_cache_rules = pattern.RewriteRuleSet([_cast, _cast_const_freqs, _const_freqs, _basic]) - -cos_sin_cache_rules = pattern.RewriteRuleSet([_cast, _basic]) +cos_sin_cache_rules = pattern.RewriteRuleSet([_cast, _cast_const_freqs, _const_freqs, _basic]) def fuse_cos_sin_cache(model: ir.Model, debug: bool) -> int: From fee229d56953d0c0394906ec6d4b9bfd05797053 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Mon, 10 Mar 2025 15:28:17 -0700 Subject: [PATCH 4/6] Make debug parameter optional --- onnxscript/rewriter/ort_fusions/cos_sin_cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/rewriter/ort_fusions/cos_sin_cache.py b/onnxscript/rewriter/ort_fusions/cos_sin_cache.py index 001484152a..476226c6a2 100644 --- a/onnxscript/rewriter/ort_fusions/cos_sin_cache.py +++ b/onnxscript/rewriter/ort_fusions/cos_sin_cache.py @@ -161,7 +161,7 @@ def rewrite( cos_sin_cache_rules = pattern.RewriteRuleSet([_cast, _cast_const_freqs, _const_freqs, _basic]) -def fuse_cos_sin_cache(model: ir.Model, debug: bool) -> int: +def fuse_cos_sin_cache(model: ir.Model, debug: bool = False) -> int: count = cos_sin_cache_rules.apply_to_model(model) if count == 0 and debug: tracer = pattern.MatchingTracer() From 4f124c83224ecd6de731d3fa8930a12cccaefa24 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Tue, 11 Mar 2025 15:41:10 -0700 Subject: [PATCH 5/6] Rewrite test case --- .../ort_fusions/_rotary_embedding_models.py | 63 ++++++++++--------- 1 file changed, 32 insertions(+), 31 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/_rotary_embedding_models.py b/onnxscript/rewriter/ort_fusions/_rotary_embedding_models.py index 885d94931e..b76ccb036f 100644 --- a/onnxscript/rewriter/ort_fusions/_rotary_embedding_models.py +++ b/onnxscript/rewriter/ort_fusions/_rotary_embedding_models.py @@ -108,46 +108,47 @@ def test_case_2(): # A partial rotary embedding example: +rotary_embedding_dim = 32 # Abbreviated as "rd" in shape descriptors below +half_rotary_embedding_dim = rotary_embedding_dim // 2 # A random inverse frequency tensor for the sake of this example. -inv_freqs_value = numpy.random.rand(1, 16, 1).astype(numpy.float32) -# inv_freqs_value = make_tensor("value", 1, dims=[1, 40, 1], vals=[1.0]*40) +inv_freqs_value = numpy.random.rand(1, half_rotary_embedding_dim, 1).astype(numpy.float32) @script() -def _partial_rotary_script( - position_ids: INT64["Batchsize", "Sequence"], query: FLOAT["Batchsize", 32, "Sequence", 80] -) -> FLOAT["Batchsize", 32, "Sequence", 80]: - inv_freqs = op.Constant(value=inv_freqs_value) - unsqueeze_2 = op.Unsqueeze(position_ids, 1) - _to_copy_2 = op.Cast(unsqueeze_2, to=1) - matmul = op.MatMul(inv_freqs, _to_copy_2) - transpose = op.Transpose(matmul, perm=[0, 2, 1]) - cat = op.Concat(transpose, transpose, axis=-1) - cos = op.Cos(cat) - sin = op.Sin(cat) - val_63 = op.Constant(value_ints=[1]) - slice_4 = op.Slice(query, [0], [32], [3], val_63) - val_73 = op.Constant(value_ints=[1]) - slice_5 = op.Slice(query, [32], [9223372036854775807], [3], val_73) - unsqueeze_3 = op.Unsqueeze(cos, 1) - unsqueeze_4 = op.Unsqueeze(sin, 1) - mul_55 = op.Mul(slice_4, unsqueeze_3) - val_106 = op.Constant(value_ints=[1]) - slice_8 = op.Slice(slice_4, [0], [16], [3], val_106) - val_116 = op.Constant(value_ints=[1]) - slice_9 = op.Slice(slice_4, [16], [9223372036854775807], [3], val_116) - neg = op.Neg(slice_9) - cat_1 = op.Concat(neg, slice_8, axis=-1) - mul_76 = op.Mul(cat_1, unsqueeze_4) - add_101 = op.Add(mul_55, mul_76) - cat_2 = op.Concat(add_101, slice_5, axis=-1) - return cat_2 +def _partial_rotary_script(position_ids, query): + inv_freqs = op.Constant(value=inv_freqs_value) # [1, rd/2, 1] + position_ids_3d = op.Unsqueeze(position_ids, 1) # [B, 1, S] + position_ids_3d_float = op.Cast(position_ids_3d, to=1) + matmul = op.MatMul(inv_freqs, position_ids_3d_float) # [B, rd/2, S] + transpose = op.Transpose(matmul, perm=[0, 2, 1]) # [B, S, rd/2] + cat = op.Concat(transpose, transpose, axis=-1) # [B, S, rd] + cos_3d = op.Cos(cat) # [B, S, rd] + sin_3d = op.Sin(cat) # [B, S, rd] + to_embed = op.Slice(query, [0], [32], [3], [1]) + unembedded = op.Slice(query, [32], [9223372036854775807], [3], [1]) + cos_4d = op.Unsqueeze(cos_3d, 1) # [B, 1, S, rd] + sin_4d = op.Unsqueeze(sin_3d, 1) # [B, 1, S, rd] + to_embed_times_cos = op.Mul(to_embed, cos_4d) + to_embed_x = op.Slice(to_embed, [0], [16], [3], [1]) + to_embed_y = op.Slice(to_embed, [16], [9223372036854775807], [3], [1]) + minus_to_embed_y = op.Neg(to_embed_y) + to_embed_rotated_90 = op.Concat(minus_to_embed_y, to_embed_x, axis=-1) + to_embed_rotated_90_times_sin = op.Mul(to_embed_rotated_90, sin_4d) + embedded = op.Add(to_embed_times_cos, to_embed_rotated_90_times_sin) + final = op.Concat(embedded, unembedded, axis=-1) + return final class _PartialRotaryTestCase: def get_onnx_model(self): if not hasattr(self, "_onnx_model"): - model_proto = _partial_rotary_script.to_model_proto() + model_proto = _partial_rotary_script.to_model_proto( + input_types=( + INT64["Batchsize", "Sequence"], + FLOAT["Batchsize", 32, "Sequence", 80], + ), + output_types=(FLOAT["Batchsize", 32, "Sequence", 80],), + ) model = ir.serde.deserialize_model(model_proto) self._onnx_model = model return self._onnx_model From 23f20325e6917f715d86b9d9156353f016230516 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Tue, 11 Mar 2025 16:43:17 -0700 Subject: [PATCH 6/6] Add comment --- onnxscript/rewriter/ort_fusions/_rotary_embedding_models.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/onnxscript/rewriter/ort_fusions/_rotary_embedding_models.py b/onnxscript/rewriter/ort_fusions/_rotary_embedding_models.py index b76ccb036f..bf5e7ba786 100644 --- a/onnxscript/rewriter/ort_fusions/_rotary_embedding_models.py +++ b/onnxscript/rewriter/ort_fusions/_rotary_embedding_models.py @@ -124,10 +124,13 @@ def _partial_rotary_script(position_ids, query): cat = op.Concat(transpose, transpose, axis=-1) # [B, S, rd] cos_3d = op.Cos(cat) # [B, S, rd] sin_3d = op.Sin(cat) # [B, S, rd] + # Split the query for partial embedding to_embed = op.Slice(query, [0], [32], [3], [1]) unembedded = op.Slice(query, [32], [9223372036854775807], [3], [1]) cos_4d = op.Unsqueeze(cos_3d, 1) # [B, 1, S, rd] sin_4d = op.Unsqueeze(sin_3d, 1) # [B, 1, S, rd] + # Compute rotation of X as X * cos + rotate_half(X) * sin, where rotate_half(X) + # essentially represents X rotated by 90 degrees to_embed_times_cos = op.Mul(to_embed, cos_4d) to_embed_x = op.Slice(to_embed, [0], [16], [3], [1]) to_embed_y = op.Slice(to_embed, [16], [9223372036854775807], [3], [1])