Skip to content

Commit edc3106

Browse files
gramalingam
authored and justinchuby committed
A couple of ort fusion fixes (microsoft#2136)
* Enable the use of SDPA fusions, along with undoing it when it does not lead to some subsequent final fusion (such as MHA or GQA). * Fix the use of constants in extracted functions from fusion. * Fix the use of Gelu instead of FastGelu in the new fusion introduced earlier today. --------- Co-authored-by: Justin Chu <[email protected]>
1 parent 2a26845 commit edc3106

File tree

9 files changed

+58
-14
lines changed

9 files changed

+58
-14
lines changed

noxfile.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,7 @@
1515
"beartype==0.17.2",
1616
"expecttest==0.1.6",
1717
"hypothesis",
18-
'numpy==1.24.4; python_version<"3.9"',
19-
'numpy==1.26.4; python_version>="3.9"',
18+
"numpy",
2019
"packaging",
2120
"parameterized",
2221
'psutil; sys_platform != "win32"',

onnxscript/rewriter/ort_fusions/_core.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
# TODO: There are some potential redundancies below. Can be targeted for optimization
2222
# once we have robust fusion.
2323
def _pre_optimize(model: ir.Model) -> ir.Model:
24-
optimize(model)
2524
# TODO: Do we need this dependence on ONNX's partial-data-propagation? There are some
2625
# extra shape-propagation and partial-data-propagation rules in ONNX that are not yet
2726
# incorporated in our optimizer.
@@ -30,7 +29,7 @@ def _pre_optimize(model: ir.Model) -> ir.Model:
3029
return model
3130

3231

33-
def fuse_xformers(model: ir.Model) -> None:
32+
def fuse_xformers(model: ir.Model) -> ir.Model:
3433
model = _pre_optimize(model)
3534
fuse_rms_normalization(model)
3635
fuse_normalization(model)
@@ -40,7 +39,10 @@ def fuse_xformers(model: ir.Model) -> None:
4039
fuse_sdpa(model)
4140
fuse_mha(model)
4241
fuse_gelu(model)
43-
remove_unused_nodes(model)
42+
# Finally: inline any intermediate fusion functions introduced that were not
43+
# consumed by other fusions, and eliminate any remaining unused nodes.
44+
optimize(model)
45+
return model
4446

4547

4648
def optimize_for_ort(model: ir.Model) -> None:

onnxscript/rewriter/ort_fusions/_test_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def ort_run(model_name: str, model, inputs):
3232
return ort_outputs
3333

3434

35-
def assert_allclose(outputs, expected_outputs, rtol=1e-2, atol=1e-2):
35+
def assert_allclose(outputs, expected_outputs, rtol=1e-4, atol=1e-4):
3636
for i, (baseline_output, optimized_output) in enumerate(zip(expected_outputs, outputs)):
3737
try:
3838
np.testing.assert_equal(baseline_output.shape, optimized_output.shape)
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
from __future__ import annotations
4+
5+
import unittest
6+
7+
import onnxscript.optimizer
8+
from onnxscript.rewriter.ort_fusions._core import fuse_xformers
9+
from onnxscript.rewriter.ort_fusions._smollm_1 import smollm_test_1
10+
from onnxscript.rewriter.ort_fusions._test_utils import assert_allclose, ort_run
11+
12+
13+
class TestFuseXformers(unittest.TestCase):
14+
def test_fuse_xformers(self):
15+
test = smollm_test_1()
16+
model = test.get_onnx_model()
17+
onnxscript.optimizer.optimize(model)
18+
inputs = test.get_ort_inputs()
19+
original_outputs = ort_run("original", model, inputs)
20+
model = fuse_xformers(model)
21+
new_outputs = ort_run("optimized", model, inputs)
22+
assert_allclose(new_outputs, original_outputs)
23+
24+
25+
if __name__ == "__main__":
26+
unittest.main()

onnxscript/rewriter/ort_fusions/gelu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def pattern(self, op, x):
2525
return result
2626

2727
def rewrite(self, op, x):
28-
return op.Gelu(x, _domain="com.microsoft")
28+
return op.FastGelu(x, _domain="com.microsoft")
2929

3030

3131
_rule = GeluTanhFusion.rule()

onnxscript/rewriter/ort_fusions/gelu_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def gelu_model(x):
4747
remove_unused_nodes(model)
4848

4949
self.assertEqual(len(model.graph), 1)
50-
self.assertEqual(model.graph.node(0).op_type, "Gelu")
50+
self.assertEqual(model.graph.node(0).op_type, "FastGelu")
5151

5252
optimized_output = test_utils.ort_run("Optimized", model, input)
5353
test_utils.assert_allclose(original_output, optimized_output)

onnxscript/rewriter/ort_fusions/rms_normalization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def check(self, op, x, scale, epsilon, compute_dtype, target_dtype):
7171
def rewrite(self, op, x, scale, epsilon, compute_dtype, target_dtype):
7272
stash_dtype = compute_dtype.value if self._cast_input else x.dtype
7373
# Note: ORT's SimplifiedLayerNormalization was placed in onnx domain by mistake.
74-
# No need to use com.microsoft domain here.
74+
# No need to use com.microsoft domain here; but this is a custom op in ORT.
7575
return op.SimplifiedLayerNormalization(
7676
x,
7777
scale,

onnxscript/rewriter/ort_fusions/sdpa_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,3 +180,7 @@ def test_sdpa_fusion(self, name, script_func):
180180

181181
# new_outputs = ort_run("optimized", model, inputs)
182182
# assert_allclose(new_outputs, original_outputs)
183+
184+
185+
if __name__ == "__main__":
186+
unittest.main()

onnxscript/rewriter/pattern.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1421,6 +1421,7 @@ def replace_pattern(new_pattern):
14211421
self.remove_nodes,
14221422
self.graph_pre_visitor,
14231423
self.graph_post_visitor,
1424+
self.as_function,
14241425
)
14251426

14261427
return [replace_pattern(p) for p in self._target_pattern.commute()]
@@ -1502,21 +1503,23 @@ class RewriteRuleClassBase:
15021503
@classmethod
15031504
def rule(cls, *args, **kwargs):
15041505
instance = cls(*args, **kwargs)
1505-
setup = instance.setup if hasattr(instance, "setup") else None
1506-
cleanup = instance.cleanup if hasattr(instance, "cleanup") else None
15071506
return RewriteRule(
15081507
instance.pattern,
15091508
instance.rewrite,
15101509
instance.check,
15111510
name=instance.name,
15121511
remove_nodes=instance.remove_nodes,
1513-
graph_pre_visitor=setup,
1514-
graph_post_visitor=cleanup,
1512+
graph_pre_visitor=instance.setup,
1513+
graph_post_visitor=instance.cleanup,
1514+
as_function=instance.as_function,
15151515
)
15161516

1517-
def __init__(self, name: str | None = None, remove_nodes: bool = True) -> None:
1517+
def __init__(
1518+
self, name: str | None = None, remove_nodes: bool = True, as_function: bool = False
1519+
) -> None:
15181520
self.name = name or self.__class__.__name__
15191521
self.remove_nodes = remove_nodes
1522+
self.as_function = as_function
15201523

15211524
def pattern(self, op, *args, **kwargs):
15221525
raise NotImplementedError("Method 'pattern' must be implemented by derived class.")
@@ -1528,6 +1531,16 @@ def check(self, op, *args, **kwargs):
15281531
def rewrite(self, op, *args, **kwargs):
15291532
raise NotImplementedError("Method 'rewrite' must be implemented by derived class.")
15301533

1534+
def setup(self):
1535+
# Optional setup function that can be overridden by derived classes. Used to do
1536+
# per model/function initialization.
1537+
pass
1538+
1539+
def cleanup(self):
1540+
# Optional cleanup function that can be overridden by derived classes. Used to do
1541+
# per model/function cleanup.
1542+
pass
1543+
15311544

15321545
class RewriteRuleSet:
15331546
def __init__(self, rules: Sequence[RewriteRule], *, commute: bool = False) -> None:

0 commit comments

Comments
 (0)