diff --git a/onnxscript/rewriter/_ir_utils.py b/onnxscript/rewriter/_ir_utils.py index d0c6a15cb7..a87d01e785 100644 --- a/onnxscript/rewriter/_ir_utils.py +++ b/onnxscript/rewriter/_ir_utils.py @@ -79,11 +79,15 @@ def get_numpy_value(val: ir.Value | None) -> np.ndarray | None: return None -def get_singleton_value(val: ir.Value | None): - """Returns element of a single element tensor constant value, and None otherwise.""" +def get_singleton_value(val: ir.Value | None, rank: int | None = None): + """Returns element of a single element tensor constant value, and None otherwise. + + If rank is specified, it checks that the value has the given rank. + """ np_val = get_numpy_value(val) if np_val is not None and np_val.size == 1: - return np_val.item() + if rank is None or (np_val.ndim == rank): + return np_val.item() return None diff --git a/onnxscript/rewriter/ort_fusions/_rotary_embedding_models.py b/onnxscript/rewriter/ort_fusions/_rotary_embedding_models.py index 9eb5a0b36e..bf5e7ba786 100644 --- a/onnxscript/rewriter/ort_fusions/_rotary_embedding_models.py +++ b/onnxscript/rewriter/ort_fusions/_rotary_embedding_models.py @@ -10,6 +10,8 @@ from onnxscript.onnx_opset import opset18 as op from onnxscript.onnx_types import FLOAT, INT64 +# A simple rotary embedding example + # x: [B, H, S, E] # position_ids: [B, S] @@ -57,6 +59,7 @@ def test_case_1(): return _TestCase1() +# A simple rotary embedding example with 1D position_ids # x: [B, H, S, E] # position_ids: [S] @script() @@ -101,3 +104,67 @@ def get_ort_inputs(self): def test_case_2(): return _TestCase2() + + +# A partial rotary embedding example: + +rotary_embedding_dim = 32 # Abbreviated as "rd" in shape descriptors below +half_rotary_embedding_dim = rotary_embedding_dim // 2 +# A random inverse frequency tensor for the sake of this example. +inv_freqs_value = numpy.random.rand(1, half_rotary_embedding_dim, 1).astype(numpy.float32) + + +@script() +def _partial_rotary_script(position_ids, query): + inv_freqs = op.Constant(value=inv_freqs_value) # [1, rd/2, 1] + position_ids_3d = op.Unsqueeze(position_ids, 1) # [B, 1, S] + position_ids_3d_float = op.Cast(position_ids_3d, to=1) + matmul = op.MatMul(inv_freqs, position_ids_3d_float) # [B, rd/2, S] + transpose = op.Transpose(matmul, perm=[0, 2, 1]) # [B, S, rd/2] + cat = op.Concat(transpose, transpose, axis=-1) # [B, S, rd] + cos_3d = op.Cos(cat) # [B, S, rd] + sin_3d = op.Sin(cat) # [B, S, rd] + # Split the query for partial embedding + to_embed = op.Slice(query, [0], [32], [3], [1]) + unembedded = op.Slice(query, [32], [9223372036854775807], [3], [1]) + cos_4d = op.Unsqueeze(cos_3d, 1) # [B, 1, S, rd] + sin_4d = op.Unsqueeze(sin_3d, 1) # [B, 1, S, rd] + # Compute rotation of X as X * cos + rotate_half(X) * sin, where rotate_half(X) + # essentially represents X rotated by 90 degrees + to_embed_times_cos = op.Mul(to_embed, cos_4d) + to_embed_x = op.Slice(to_embed, [0], [16], [3], [1]) + to_embed_y = op.Slice(to_embed, [16], [9223372036854775807], [3], [1]) + minus_to_embed_y = op.Neg(to_embed_y) + to_embed_rotated_90 = op.Concat(minus_to_embed_y, to_embed_x, axis=-1) + to_embed_rotated_90_times_sin = op.Mul(to_embed_rotated_90, sin_4d) + embedded = op.Add(to_embed_times_cos, to_embed_rotated_90_times_sin) + final = op.Concat(embedded, unembedded, axis=-1) + return final + + +class _PartialRotaryTestCase: + def get_onnx_model(self): + if not hasattr(self, "_onnx_model"): + model_proto = _partial_rotary_script.to_model_proto( + input_types=( + INT64["Batchsize", "Sequence"], + FLOAT["Batchsize", 32, "Sequence", 80], + ), + output_types=(FLOAT["Batchsize", 32, "Sequence", 80],), + ) + model = ir.serde.deserialize_model(model_proto) + self._onnx_model = model + return self._onnx_model + + def get_ort_inputs(self): + if not hasattr(self, "_ort_inputs"): + inputs = { + "query": numpy.random.rand(1, 32, 8, 80).astype(numpy.float32), + "position_ids": numpy.arange(8, dtype=numpy.int64).reshape(1, 8), + } + self._ort_inputs = inputs + return self._ort_inputs + + +def partial_rotary_test_case(): + return _PartialRotaryTestCase() diff --git a/onnxscript/rewriter/ort_fusions/cos_sin_cache.py b/onnxscript/rewriter/ort_fusions/cos_sin_cache.py index d1a391e9ae..476226c6a2 100644 --- a/onnxscript/rewriter/ort_fusions/cos_sin_cache.py +++ b/onnxscript/rewriter/ort_fusions/cos_sin_cache.py @@ -152,16 +152,21 @@ def rewrite( _cast_const_freqs = CosSinCacheFusion.rule( "CosSinCache_cast_const_freqs", 2048, cast=True, const_freqs=True ) -_cast = CosSinCacheFusion.rule( - "CosSinCache_cast_no_const_freqs", 2048, cast=True, const_freqs=False +_cast = CosSinCacheFusion.rule("CosSinCache_cast", 2048, cast=True, const_freqs=False) +_const_freqs = CosSinCacheFusion.rule( + "CosSinCache_const_freqs", 2048, cast=False, const_freqs=True ) _basic = CosSinCacheFusion.rule("CosSinCache", 2048, cast=False) -cos_sin_cache_rules = pattern.RewriteRuleSet([_cast, _cast_const_freqs, _basic]) +cos_sin_cache_rules = pattern.RewriteRuleSet([_cast, _cast_const_freqs, _const_freqs, _basic]) -def fuse_cos_sin_cache(model: ir.Model) -> int: +def fuse_cos_sin_cache(model: ir.Model, debug: bool = False) -> int: count = cos_sin_cache_rules.apply_to_model(model) + if count == 0 and debug: + tracer = pattern.MatchingTracer() + cos_sin_cache_rules.apply_to_model(model, tracer=tracer) + tracer.report() if count != 0: remove_unused_nodes(model) return count diff --git a/onnxscript/rewriter/ort_fusions/cos_sin_cache_test.py b/onnxscript/rewriter/ort_fusions/cos_sin_cache_test.py index fcc735f2cc..67cb058fd3 100644 --- a/onnxscript/rewriter/ort_fusions/cos_sin_cache_test.py +++ b/onnxscript/rewriter/ort_fusions/cos_sin_cache_test.py @@ -7,11 +7,18 @@ from parameterized import parameterized import onnxscript.optimizer -from onnxscript.rewriter.ort_fusions._rotary_embedding_models import test_case_1, test_case_2 +from onnxscript.rewriter.ort_fusions._rotary_embedding_models import ( + partial_rotary_test_case, + test_case_1, + test_case_2, +) from onnxscript.rewriter.ort_fusions._smollm_1 import smollm_test_1 from onnxscript.rewriter.ort_fusions._test_utils import assert_allclose, ort_run from onnxscript.rewriter.ort_fusions.cos_sin_cache import fuse_cos_sin_cache -from onnxscript.rewriter.ort_fusions.rotary_embedding import fuse_rotary_embedding +from onnxscript.rewriter.ort_fusions.rotary_embedding import ( + fuse_partial_rotary_embedding, + fuse_rotary_embedding, +) class TestCosSinCacheTransform(unittest.TestCase): @@ -29,6 +36,10 @@ class TestCosSinCacheTransform(unittest.TestCase): "test_case_2", test_case_2, ), + ( + "partial_rotary_test_case", + partial_rotary_test_case, + ), ] ) def test_cos_sin_fusion(self, name, test_data_constructor): @@ -44,6 +55,21 @@ def test_cos_sin_fusion(self, name, test_data_constructor): new_outputs = ort_run("optimized", model, inputs) assert_allclose(new_outputs, original_outputs) + def test_partial_rotary_fusion(self): + test = partial_rotary_test_case() + model = test.get_onnx_model() + onnxscript.optimizer.optimize(model) + inputs = test.get_ort_inputs() + original_outputs = ort_run("original", model, inputs) + count = fuse_rotary_embedding(model) + self.assertGreater(count, 0) + count = fuse_cos_sin_cache(model) + self.assertGreater(count, 0) + count = fuse_partial_rotary_embedding(model) + self.assertGreater(count, 0) + new_outputs = ort_run("optimized", model, inputs) + assert_allclose(new_outputs, original_outputs) + if __name__ == "__main__": unittest.main() diff --git a/onnxscript/rewriter/ort_fusions/rotary_embedding.py b/onnxscript/rewriter/ort_fusions/rotary_embedding.py index d8ab31a428..c637fcc66f 100644 --- a/onnxscript/rewriter/ort_fusions/rotary_embedding.py +++ b/onnxscript/rewriter/ort_fusions/rotary_embedding.py @@ -53,11 +53,69 @@ def rewrite(self, op, x, cos, sin, **_): ) +class PartialRotaryEmbeddingFusion(pattern.RewriteRuleClassBase): + def pattern(self, op, x, end1, start2): + x_part_1 = op.Slice(x, [0], end1, [3], [1]) + x_part_2 = op.Slice(x, start2, [9223372036854775807], [3], [1]) + x_part_1_rope = op.RotaryEmbedding( + x_part_1, + _allow_other_inputs=True, + _allow_other_attributes=True, + _domain="com.microsoft", + _outputs=["x_part_1_rope"], + ) + return op.Concat(x_part_1_rope, x_part_2, axis=-1) + + def check(self, op, x, end1, start2, x_part_1_rope, **_): + end1_value = _ir_utils.get_singleton_value(end1) + start2_value = _ir_utils.get_singleton_value(start2) + if not isinstance(end1_value, int) or not isinstance(start2_value, int): + return False + if end1_value != start2_value: + return False + rotary_embedding_attributes = x_part_1_rope.producer().attributes + if "rotary_embedding_dim" in rotary_embedding_attributes: + return False + if ( + "interleaved" in rotary_embedding_attributes + and rotary_embedding_attributes["interleaved"].value != 0 + ): + return False + return True + + def rewrite(self, op, x, end1, x_part_1_rope, **_): + # Create a modified version of the RotaryEmbedding op: + rotary_embedding_dim = _ir_utils.get_singleton_value(end1) + original_node = x_part_1_rope.producer() + inputs = list(original_node.inputs) + inputs[0] = x + attrs = dict(original_node.attributes) + attrs["rotary_embedding_dim"] = rotary_embedding_dim + return op.RotaryEmbedding( + *inputs, + **attrs, + _domain="com.microsoft", + ) + + _rule = RotaryEmbeddingFusion.rule() +_partial_embedding_rule = PartialRotaryEmbeddingFusion.rule() + rotary_embedding_rules = pattern.RewriteRuleSet([_rule]) +partial_embedding_rules = pattern.RewriteRuleSet([_partial_embedding_rule]) + def fuse_rotary_embedding(model: ir.Model) -> int: count = rotary_embedding_rules.apply_to_model(model) return count + + +def fuse_partial_rotary_embedding(model: ir.Model, debug: bool = False) -> int: + count = partial_embedding_rules.apply_to_model(model) + if count == 0 and debug: + tracer = pattern.MatchingTracer() + partial_embedding_rules.apply_to_model(model, tracer=tracer) + tracer.report() + return count