Skip to content

Extensions to transformer fusions #2082

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Mar 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .lintrunner.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ exclude_patterns = [
'onnxscript/rewriter/onnxruntime/transformers/fastgelu.py', # FIXME
'onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py', # FIXME
'onnxscript/rewriter/ort_fusions/_smollm_*.py', # onnxscript code
'onnxscript/rewriter/ort_fusions/_rotary_embedding_models.py', # onnxscript code
'onnxscript/_legacy_ir/irbuilder.py', # FIXME
'onnxscript/rewriter/onnxruntime/transformers/multihead_attention.py', # FIXME
'onnxscript/tools/function_unittest_producer.py', # FIXME
Expand Down
15 changes: 15 additions & 0 deletions onnxscript/rewriter/_ir_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,21 @@
return math.isclose(scalar, expected, rel_tol=rtol)


def is_1d_value(val: ir.Value | None, expected: list[int]) -> bool:
    """Returns True if the value is a 1d int64 tensor with given value, and False otherwise."""
    if val is None:
        return False
    # NOTE(review): assumes ir.TypeProtocol is runtime-checkable — confirm; a
    # non-runtime-checkable Protocol would raise TypeError here.
    if not isinstance(val.type, ir.TypeProtocol):
        return False
    np_val = get_numpy_value(val)
    if np_val is None:
        # Value is not a statically-known constant; nothing to compare.
        return False
    # Size/dtype gate. A multi-dimensional tensor with the same total size
    # still fails the tolist() comparison below, since tolist() nests one
    # list per dimension while `expected` is flat.
    if (np_val.size != len(expected)) or (val.type.dtype != ir.DataType.INT64):
        return False
    values = np_val.tolist()
    return values == expected


def has_rank(value: ir.Value | None, rank: int) -> bool:
"""Returns True if the value is statically known to have the given rank, and False otherwise."""
if value is None:
Expand Down
103 changes: 103 additions & 0 deletions onnxscript/rewriter/ort_fusions/_rotary_embedding_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Small test case models for rotary embedding."""

import numpy

import onnxscript.ir as ir
from onnxscript import script
from onnxscript.onnx_opset import opset18 as op
from onnxscript.onnx_types import FLOAT, INT64


# x: [B, H, S, E]
# position_ids: [B, S]
@script()
def _test_case_1_script(x: FLOAT[1, 4, 8, 8], position_ids: INT64[1, 8]) -> FLOAT[1, 4, 8, 8]:
    """Rotary-embedding test model: position_ids has rank 2 ([B, S])."""
    inv_freq = op.Constant(value_floats=[1.0, 2.0, 3.0, 4.0])
    inv_freq_3d = op.Unsqueeze(inv_freq, [0, 2])
    position_ids_expanded = op.Unsqueeze(position_ids, [1])  # => [B, 1, S]
    position_ids_float = op.Cast(position_ids_expanded, to=ir.DataType.FLOAT)
    freqs = op.MatMul(inv_freq_3d, position_ids_float)  # [B, E, S]
    freqs = op.Transpose(freqs, perm=[0, 2, 1])  # [B, S, E]
    emb = op.Concat(freqs, freqs, axis=-1)
    cos = op.Cos(emb)
    sin = op.Sin(emb)
    cos_4d = op.Unsqueeze(cos, 1)
    sin_4d = op.Unsqueeze(sin, 1)

    # Rotate-half: split the embedding axis (axis 3) into two halves and
    # recombine as [-x2, x1] before applying the sin component.
    x1 = op.Slice(x, [0], [4], [3], [1])
    x2 = op.Slice(x, [4], [8], [3], [1])
    minus_x2 = op.Neg(x2)
    rotated_x = op.Concat(minus_x2, x1, axis=-1)
    rotary_embedding = op.Add(x * cos_4d, rotated_x * sin_4d)
    return rotary_embedding


class _TestCase1:
def get_onnx_model(self):
if not hasattr(self, "_onnx_model"):
model_proto = _test_case_1_script.to_model_proto()
model = ir.serde.deserialize_model(model_proto)
self._onnx_model = model
return self._onnx_model

def get_ort_inputs(self):
if not hasattr(self, "_ort_inputs"):
inputs = {
"x": numpy.random.rand(1, 4, 8, 8).astype(numpy.float32),
"position_ids": numpy.arange(8, dtype=numpy.int64).reshape(1, 8),
}
self._ort_inputs = inputs
return self._ort_inputs


def test_case_1():
    """Return a fresh _TestCase1 fixture (rotary embedding, [B, S] position_ids)."""
    return _TestCase1()


# x: [B, H, S, E]
# position_ids: [S]
@script()
def _test_case_2_script(x: FLOAT[1, 4, 8, 8], position_ids: INT64[8]) -> FLOAT[1, 4, 8, 8]:
    """Rotary-embedding test model: position_ids has rank 1 ([S])."""
    inv_freq = op.Constant(value_floats=[1.0, 2.0, 3.0, 4.0])
    inv_freq_3d = op.Unsqueeze(inv_freq, [0, 2])
    position_ids_expanded = op.Unsqueeze(position_ids, [0, 1])  # => [1, 1, S]
    position_ids_float = op.Cast(position_ids_expanded, to=ir.DataType.FLOAT)
    freqs = op.MatMul(inv_freq_3d, position_ids_float)  # [B, E, S]
    freqs = op.Transpose(freqs, perm=[0, 2, 1])  # [B, S, E]
    emb = op.Concat(freqs, freqs, axis=-1)
    cos = op.Cos(emb)
    sin = op.Sin(emb)
    cos_4d = op.Unsqueeze(cos, 1)
    sin_4d = op.Unsqueeze(sin, 1)

    # Rotate-half: split the embedding axis (axis 3) into two halves and
    # recombine as [-x2, x1] before applying the sin component.
    x1 = op.Slice(x, [0], [4], [3], [1])
    x2 = op.Slice(x, [4], [8], [3], [1])
    minus_x2 = op.Neg(x2)
    rotated_x = op.Concat(minus_x2, x1, axis=-1)
    rotary_embedding = op.Add(x * cos_4d, rotated_x * sin_4d)
    return rotary_embedding


class _TestCase2:
def get_onnx_model(self):
if not hasattr(self, "_onnx_model"):
model_proto = _test_case_2_script.to_model_proto()
model = ir.serde.deserialize_model(model_proto)
self._onnx_model = model
return self._onnx_model

def get_ort_inputs(self):
if not hasattr(self, "_ort_inputs"):
inputs = {
"x": numpy.random.rand(1, 4, 8, 8).astype(numpy.float32),
"position_ids": numpy.arange(8, dtype=numpy.int64).reshape(8),
}
self._ort_inputs = inputs
return self._ort_inputs


def test_case_2():
    """Return a fresh _TestCase2 fixture (rotary embedding, [S] position_ids)."""
    return _TestCase2()
6 changes: 5 additions & 1 deletion onnxscript/rewriter/ort_fusions/_smollm_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def make_model_with_random_weights():
return model


class TestData:
class _SmollmTest1:
def get_onnx_model(self):
if not hasattr(self, "_onnx_model"):
model_proto = make_model_with_random_weights()
Expand All @@ -251,3 +251,7 @@ def get_ort_inputs(self):
}
self._ort_inputs = inputs
return self._ort_inputs


def smollm_test_1():
return _SmollmTest1()
6 changes: 5 additions & 1 deletion onnxscript/rewriter/ort_fusions/_smollm_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ def make_model_with_random_weights():
return model


class TestData:
class _SmollmTest2:
def get_onnx_model(self):
if not hasattr(self, "_onnx_model"):
model_proto = make_model_with_random_weights()
Expand All @@ -465,3 +465,7 @@ def get_ort_inputs(self):
}
self._ort_inputs = inputs
return self._ort_inputs


def smollm_test_2():
return _SmollmTest2()
30 changes: 19 additions & 11 deletions onnxscript/rewriter/ort_fusions/cos_sin_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,18 @@ def __init__(
def cleanup(self):
self._inv_freq_cos_sin_cache.clear()

def pattern(self, op, x, inv_freq, position_ids, interleaved, num_heads, freqs, dtype):
def pattern(
self, op, x, inv_freq, position_ids, interleaved, num_heads, freqs, dtype, extra_dims
):
if not self._const_freqs:
# Compute freqs from inv_freq and position_ids. In the _const_freqs case,
# this computation has been constant-folded away and freqs is a constant.
# B: batch size, S: sequence length, E: embedding dimension
# position_ids: [B, S]
# position_ids: [B, S] or [S]
# inv_freq: [1, E, 1]
position_ids_expanded = op.Unsqueeze(position_ids, 1) # [B, S] => [B, 1, S]
position_ids_expanded = op.Unsqueeze(
position_ids, extra_dims
) # [B, S] | [S] => [B, 1, S]
position_ids_expanded = op.Cast(position_ids_expanded, to=ir.DataType.FLOAT)
# if self._reshape:
# position_ids_expanded = op.Expand(position_ids_expanded, _allow_other_inputs=True)
Expand All @@ -92,11 +96,17 @@ def pattern(self, op, x, inv_freq, position_ids, interleaved, num_heads, freqs,
_domain="ai.onnxruntime.fusion",
)

def check(self, context, inv_freq, position_ids, freqs, **_):
def check(self, context, inv_freq, position_ids, freqs, extra_dims, **_):
# TODO(rama): handle redundant reshape/expand
if self._const_freqs:
return (freqs.const_value is not None) and _ir_utils.has_rank(freqs, 3)
if not _ir_utils.has_rank(position_ids, 2):
if (
_ir_utils.has_rank(position_ids, 2) and _ir_utils.is_singleton_value(extra_dims, 1)
) or (
_ir_utils.has_rank(position_ids, 1) and _ir_utils.is_1d_value(extra_dims, [0, 1])
):
pass
else:
return False
if not _ir_utils.has_rank(inv_freq, 3):
return False
Expand Down Expand Up @@ -125,6 +135,9 @@ def rewrite(
cos_2d = op.Cast(cos_2d, to=dtype)
sin_2d = op.Cast(sin_2d, to=dtype)
self._inv_freq_cos_sin_cache[inv_freq] = (cos_2d, sin_2d)
if _ir_utils.has_rank(position_ids, 1):
zero_1d = op.Constant(value_ints=[0])
position_ids = op.Unsqueeze(position_ids, zero_1d)
return op.RotaryEmbedding(
x,
position_ids,
Expand All @@ -146,14 +159,9 @@ def rewrite(

cos_sin_cache_rules = pattern.RewriteRuleSet([_cast, _cast_const_freqs, _basic])

debug: bool = True


def fuse_cos_sin_cache(model: ir.Model) -> int:
count = cos_sin_cache_rules.apply_to_model(model)
if count == 0 and debug:
cos_sin_cache_rules.apply_to_model(model, debug=True)
else:
print(f"CosSinCache count: {count}")
if count != 0:
remove_unused_nodes(model)
return count
29 changes: 24 additions & 5 deletions onnxscript/rewriter/ort_fusions/cos_sin_cache_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,38 @@

import unittest

from parameterized import parameterized

import onnxscript.optimizer
from onnxscript.rewriter.ort_fusions._smollm_1 import TestData
from onnxscript.rewriter.ort_fusions._rotary_embedding_models import test_case_1, test_case_2
from onnxscript.rewriter.ort_fusions._smollm_1 import smollm_test_1
from onnxscript.rewriter.ort_fusions._test_utils import assert_allclose, ort_run
from onnxscript.rewriter.ort_fusions.cos_sin_cache import fuse_cos_sin_cache
from onnxscript.rewriter.ort_fusions.rotary_embedding import fuse_rotary_embedding


class TestCosSinCacheTransform(unittest.TestCase):
def test_smollm(self):
smollm_test = TestData()
model = smollm_test.get_onnx_model()
@parameterized.expand(
[
(
"smollm_test_1",
smollm_test_1,
),
(
"test_case_1",
test_case_1,
),
(
"test_case_2",
test_case_2,
),
]
)
def test_cos_sin_fusion(self, name, test_data_constructor):
test = test_data_constructor()
model = test.get_onnx_model()
onnxscript.optimizer.optimize(model)
inputs = smollm_test.get_ort_inputs()
inputs = test.get_ort_inputs()
original_outputs = ort_run("original", model, inputs)
count = fuse_rotary_embedding(model)
self.assertGreater(count, 0)
Expand Down
6 changes: 0 additions & 6 deletions onnxscript/rewriter/ort_fusions/mha.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,13 +186,7 @@ def rewrite(

mha_rules = pattern.RewriteRuleSet([_rule1])

debug: bool = True


def fuse_mha(model: ir.Model) -> int:
count = mha_rules.apply_to_model(model)
if count == 0 and debug:
mha_rules.apply_to_model(model, debug=True)
else:
print(f"MHA count: {count}")
return count
4 changes: 2 additions & 2 deletions onnxscript/rewriter/ort_fusions/mha_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@

import onnxscript.optimizer
import onnxscript.rewriter.ort_fusions._core as xformers
from onnxscript.rewriter.ort_fusions._smollm_2 import TestData
from onnxscript.rewriter.ort_fusions._smollm_2 import smollm_test_2
from onnxscript.rewriter.ort_fusions._test_utils import assert_allclose, ort_run


class TestMultiHeadAttention(unittest.TestCase):
def test_smollm(self):
# Generate model
smollm_test = TestData()
smollm_test = smollm_test_2()
model = smollm_test.get_onnx_model()
onnxscript.optimizer.optimize(model)
xformers.fuse_rms_normalization(model)
Expand Down
4 changes: 2 additions & 2 deletions onnxscript/rewriter/ort_fusions/rms_normalization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
import unittest

import onnxscript.optimizer
from onnxscript.rewriter.ort_fusions._smollm_1 import TestData
from onnxscript.rewriter.ort_fusions._smollm_1 import smollm_test_1
from onnxscript.rewriter.ort_fusions._test_utils import assert_allclose, ort_run
from onnxscript.rewriter.ort_fusions.rms_normalization import fuse_rms_normalization


class TestRmsNormalization(unittest.TestCase):
def test_smollm(self):
smollm_test = TestData()
smollm_test = smollm_test_1()
model = smollm_test.get_onnx_model()
onnxscript.optimizer.optimize(model)
inputs = smollm_test.get_ort_inputs()
Expand Down
6 changes: 0 additions & 6 deletions onnxscript/rewriter/ort_fusions/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,7 @@ def rewrite(self, op, x, cos, sin, **_):

rotary_embedding_rules = pattern.RewriteRuleSet([_rule])

debug: bool = True


def fuse_rotary_embedding(model: ir.Model) -> int:
count = rotary_embedding_rules.apply_to_model(model)
if count == 0 and debug:
rotary_embedding_rules.apply_to_model(model, debug=True)
else:
print(f"Rotary Embedding count: {count}")
return count
23 changes: 19 additions & 4 deletions onnxscript/rewriter/ort_fusions/rotary_embedding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,30 @@

import unittest

from parameterized import parameterized

import onnxscript.optimizer
from onnxscript.rewriter.ort_fusions._smollm_1 import TestData
from onnxscript.rewriter.ort_fusions._rotary_embedding_models import test_case_1
from onnxscript.rewriter.ort_fusions._smollm_1 import smollm_test_1
from onnxscript.rewriter.ort_fusions.rotary_embedding import fuse_rotary_embedding


class TestRotaryEmbedding(unittest.TestCase):
def test_smollm(self):
smollm_test = TestData()
model = smollm_test.get_onnx_model()
@parameterized.expand(
[
(
"test_case_1",
test_case_1,
),
(
"smollm_test_1",
smollm_test_1,
),
]
)
def test_rotary_embedding_fusion(self, name, test_data_constructor):
test = test_data_constructor()
model = test.get_onnx_model()
onnxscript.optimizer.optimize(model)
fuse_rotary_embedding(model)
op_types = [n.op_type for n in model.graph]
Expand Down
6 changes: 0 additions & 6 deletions onnxscript/rewriter/ort_fusions/sdpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,7 @@ def rewrite(self, op, query, key_transposed, value, mask, **_):
]
)

debug: bool = True


def fuse_sdpa(model: ir.Model) -> int:
count = sdpa_rules.apply_to_model(model)
if count == 0 and debug:
sdpa_rules.apply_to_model(model, debug=True)
else:
print(f"SDPA count: {count}")
return count
Loading
Loading