Commit cb264d5

gramalingam authored and justinchuby committed
A couple of ort fusion fixes (microsoft#2136)
* Enable the use of SDPA fusions, along with undoing the fusion when it does not lead to some subsequent final fusion (such as MHA or GQA).
* Fix the use of constants in functions extracted by fusion.
* Fix the use of Gelu instead of FastGelu in the new fusion introduced earlier today.

Co-authored-by: Justin Chu <[email protected]>
1 parent 2771ad1 commit cb264d5

File tree

10 files changed, +185 -19 lines changed


noxfile.py

Lines changed: 1 addition & 2 deletions
@@ -15,8 +15,7 @@
     "beartype==0.17.2",
     "expecttest==0.1.6",
     "hypothesis",
-    'numpy==1.24.4; python_version<"3.9"',
-    'numpy==1.26.4; python_version>="3.9"',
+    "numpy",
     "packaging",
     "parameterized",
     'psutil; sys_platform != "win32"',

onnxscript/rewriter/ort_fusions/_core.py

Lines changed: 32 additions & 6 deletions
@@ -4,7 +4,14 @@

 import onnxscript.ir as ir
 from onnxscript.ir.passes.common import shape_inference
-from onnxscript.optimizer import optimize, remove_unused_nodes
+from onnxscript.optimizer import optimize
+from onnxscript.rewriter import rewrite
+from onnxscript.rewriter.ort_fusions import (
+    fused_matmul_rule_sets,
+    # group_normalization_merge_silu,
+    instance_to_group_normalization,
+    softmax,
+)
 from onnxscript.rewriter.ort_fusions.cos_sin_cache import fuse_cos_sin_cache
 from onnxscript.rewriter.ort_fusions.gelu import fuse_gelu
 from onnxscript.rewriter.ort_fusions.mha import fuse_mha

@@ -21,7 +28,6 @@
 # TODO: There are some potential redundancies below. Can be targeted for optimization
 # once we have robust fusion.
 def _pre_optimize(model: ir.Model) -> ir.Model:
-    optimize(model)
     # TODO: Do we need this dependence on ONNX's partial-data-propagation? There are some
     # extra shape-propagation and partial-data-propagation rules in ONNX that are not yet
     # incorporated in our optimizer.

@@ -30,7 +36,7 @@ def _pre_optimize(model: ir.Model) -> ir.Model:
     return model


-def fuse_xformers(model: ir.Model) -> None:
+def fuse_xformers(model: ir.Model) -> ir.Model:
     model = _pre_optimize(model)
     fuse_rms_normalization(model)
     fuse_normalization(model)

@@ -40,9 +46,29 @@
     fuse_sdpa(model)
     fuse_mha(model)
     fuse_gelu(model)
-    remove_unused_nodes(model)
+    # Finally: inline any intermediate fusion functions introduced that were not
+    # consumed by other fusions, and eliminate any remaining unused nodes.
+    optimize(model)
+    return model
+

+def optimize_for_ort(model: ir.Model, config_name: str | None = None) -> ir.Model:
+    """
+    Optimize the model for ORT backend.
+
+    TODO: config_name is not used yet. It should be used to select the appropriate
+    optimization configuration (for an EP). Currently, a default implementation is used.
+
+    Args:
+        model: The model to optimize.
+        config_name: The name of the configuration to use for optimization.
+            Typically it identifies the Execution Provider (EP) to optimize for.
+            If None, the default configuration will be used.
+
+    Returns:
+        The optimized model.
+    """

-def optimize_for_ort(model: ir.Model) -> None:
-    # TODO(rama): Include the other optimizations
     fuse_xformers(model)
+    rewrite(model, ORT_PATTERN_REWRITE_RULES)
+    return model
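
Note: with these changes both entry points return the rewritten model. A minimal usage sketch (not part of the diff; it assumes onnxscript.ir's load/save helpers and the module path above):

import onnxscript.ir as ir
from onnxscript.rewriter.ort_fusions._core import optimize_for_ort

model = ir.load("model.onnx")  # hypothetical input path
model = optimize_for_ort(model)  # fuse_xformers, then ORT_PATTERN_REWRITE_RULES
ir.save(model, "model_optimized.onnx")  # hypothetical output path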

onnxscript/rewriter/ort_fusions/_test_utils.py

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ def ort_run(model_name: str, model, inputs):
     return ort_outputs


-def assert_allclose(outputs, expected_outputs, rtol=1e-2, atol=1e-2):
+def assert_allclose(outputs, expected_outputs, rtol=1e-4, atol=1e-4):
     for i, (baseline_output, optimized_output) in enumerate(zip(expected_outputs, outputs)):
         try:
             np.testing.assert_equal(baseline_output.shape, optimized_output.shape)

onnxscript/rewriter/ort_fusions/fuse_xformers_test.py

Lines changed: 26 additions & 0 deletions

@@ -0,0 +1,26 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+from __future__ import annotations
+
+import unittest
+
+import onnxscript.optimizer
+from onnxscript.rewriter.ort_fusions._core import fuse_xformers
+from onnxscript.rewriter.ort_fusions._smollm_1 import smollm_test_1
+from onnxscript.rewriter.ort_fusions._test_utils import assert_allclose, ort_run
+
+
+class TestFuseXformers(unittest.TestCase):
+    def test_fuse_xformers(self):
+        test = smollm_test_1()
+        model = test.get_onnx_model()
+        onnxscript.optimizer.optimize(model)
+        inputs = test.get_ort_inputs()
+        original_outputs = ort_run("original", model, inputs)
+        model = fuse_xformers(model)
+        new_outputs = ort_run("optimized", model, inputs)
+        assert_allclose(new_outputs, original_outputs)
+
+
+if __name__ == "__main__":
+    unittest.main()

onnxscript/rewriter/ort_fusions/gelu.py

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ def pattern(self, op, x):
         return result

     def rewrite(self, op, x):
-        return op.Gelu(x, _domain="com.microsoft")
+        return op.FastGelu(x, _domain="com.microsoft")


 _rule = GeluTanhFusion.rule()
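
For context: the pattern matched by GeluTanhFusion is the tanh approximation of Gelu, which is what ORT's com.microsoft FastGelu implements, so FastGelu (not exact Gelu) is the right target op. An illustrative numpy sketch of the two formulas (not from this commit):

import math
import numpy as np

def gelu_erf(x: np.ndarray) -> np.ndarray:
    # exact Gelu: 0.5 * x * (1 + erf(x / sqrt(2)))
    return 0.5 * x * (1.0 + np.array([math.erf(v / math.sqrt(2.0)) for v in x]))

def gelu_tanh(x: np.ndarray) -> np.ndarray:
    # tanh approximation matched by GeluTanhFusion
    return 0.5 * x * (1.0 + np.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))

x = np.linspace(-3.0, 3.0, 13)
print(np.max(np.abs(gelu_erf(x) - gelu_tanh(x))))  # small but nonzero difference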

onnxscript/rewriter/ort_fusions/gelu_test.py

Lines changed: 1 addition & 1 deletion
@@ -47,7 +47,7 @@ def gelu_model(x):
         remove_unused_nodes(model)

         self.assertEqual(len(model.graph), 1)
-        self.assertEqual(model.graph.node(0).op_type, "Gelu")
+        self.assertEqual(model.graph.node(0).op_type, "FastGelu")

         optimized_output = test_utils.ort_run("Optimized", model, input)
         test_utils.assert_allclose(original_output, optimized_output)

onnxscript/rewriter/ort_fusions/rms_normalization.py

Lines changed: 1 addition & 1 deletion
@@ -71,7 +71,7 @@ def check(self, op, x, scale, epsilon, compute_dtype, target_dtype):
     def rewrite(self, op, x, scale, epsilon, compute_dtype, target_dtype):
         stash_dtype = compute_dtype.value if self._cast_input else x.dtype
         # Note: ORT's SimplifiedLayerNormalization was placed in onnx domain by mistake.
-        # No need to use com.microsoft domain here.
+        # No need to use com.microsoft domain here; but this is a custom op in ORT.
         return op.SimplifiedLayerNormalization(
             x,
             scale,

onnxscript/rewriter/ort_fusions/sdpa.py

Lines changed: 2 additions & 2 deletions
@@ -9,8 +9,8 @@


 class SDPA(pattern.RewriteRuleClassBase):
-    def __init__(self, name: str, *, use_mask: bool, pre_scale: bool):
-        super().__init__(name=name)
+    def __init__(self, name: str, *, use_mask: bool, pre_scale: bool, use_mul: bool):
+        super().__init__(name=name, as_function=True)
         self._use_mask = use_mask
         self._pre_scale = pre_scale
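
The as_function=True flag passed to the base class is what enables the "undo" described in the commit message: the fused SDPA subgraph is extracted into a model-local function, so the final optimize(model) call in fuse_xformers can inline it back if no downstream MHA/GQA fusion consumed it. A hedged sketch of how a rule class opts in (hypothetical rule; the mechanism is the same):

from onnxscript.rewriter import pattern

class MyFusion(pattern.RewriteRuleClassBase):
    def __init__(self):
        # as_function=True extracts the replacement into a model-local function,
        # which a later pass can inline (i.e., undo) if nothing consumes it.
        super().__init__(name="MyFusion", as_function=True)

    def pattern(self, op, x):
        return op.Relu(op.Identity(x))  # hypothetical pattern to match

    def check(self, op, x):
        return True  # no extra match conditions for this sketch

    def rewrite(self, op, x):
        return op.Relu(x)  # hypothetical replacement

rule = MyFusion.rule()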

onnxscript/rewriter/ort_fusions/sdpa_test.py

Lines changed: 4 additions & 0 deletions
@@ -180,3 +180,7 @@ def test_sdpa_fusion(self, name, script_func):

         # new_outputs = ort_run("optimized", model, inputs)
         # assert_allclose(new_outputs, original_outputs)
+
+
+if __name__ == "__main__":
+    unittest.main()

onnxscript/rewriter/pattern.py

Lines changed: 116 additions & 5 deletions
@@ -1421,6 +1421,7 @@ def replace_pattern(new_pattern):
                 self.remove_nodes,
                 self.graph_pre_visitor,
                 self.graph_post_visitor,
+                self.as_function,
             )

         return [replace_pattern(p) for p in self._target_pattern.commute()]

@@ -1502,21 +1503,23 @@ class RewriteRuleClassBase:
     @classmethod
     def rule(cls, *args, **kwargs):
         instance = cls(*args, **kwargs)
-        setup = instance.setup if hasattr(instance, "setup") else None
-        cleanup = instance.cleanup if hasattr(instance, "cleanup") else None
         return RewriteRule(
             instance.pattern,
             instance.rewrite,
             instance.check,
             name=instance.name,
             remove_nodes=instance.remove_nodes,
-            graph_pre_visitor=setup,
-            graph_post_visitor=cleanup,
+            graph_pre_visitor=instance.setup,
+            graph_post_visitor=instance.cleanup,
+            as_function=instance.as_function,
         )

-    def __init__(self, name: str | None = None, remove_nodes: bool = True) -> None:
+    def __init__(
+        self, name: str | None = None, remove_nodes: bool = True, as_function: bool = False
+    ) -> None:
         self.name = name or self.__class__.__name__
         self.remove_nodes = remove_nodes
+        self.as_function = as_function

     def pattern(self, op, *args, **kwargs):
         raise NotImplementedError("Method 'pattern' must be implemented by derived class.")

@@ -1528,6 +1531,114 @@ def check(self, op, *args, **kwargs):
     def rewrite(self, op, *args, **kwargs):
         raise NotImplementedError("Method 'rewrite' must be implemented by derived class.")

+    def setup(self):
+        # Optional setup function that can be overridden by derived classes. Used to do
+        # per model/function initialization.
+        pass
+
+    def cleanup(self):
+        # Optional cleanup function that can be overridden by derived classes. Used to do
+        # per model/function cleanup.
+        pass
+
+
+def _copy_for_function(
+    inputs: Sequence[ir.Value | None], nodes: Sequence[ir.Node], outputs: Sequence[ir.Value]
+):
+    """Utility function to extract a subgraph out as a function."""
+    value_map: dict[ir.Value, ir.Value] = {}
+    function_inputs: list[ir.Value] = []
+    constant_nodes: list[ir.Node] = []
+    for input in inputs:
+        # Create a function input (formal-parameter value) to represent this value:
+        new_value = (
+            ir.Value(
+                name=input.name,
+                shape=input.shape,
+                type=input.type,
+                doc_string=input.doc_string,
+            )
+            if input
+            else ir.Value()  # dummy parameter for a None input
+        )
+        if input is not None:
+            value_map[input] = new_value
+        function_inputs.append(new_value)
+
+    def copy_value(value: ir.Value | None) -> ir.Value | None:
+        if value is None:
+            return None
+        if value not in value_map:
+            const_value = value.const_value
+            if const_value is not None:
+                # create a Constant node to represent the value
+                value_attr = ir.AttrTensor("value", const_value)
+                const_node = ir.Node("", "Constant", [], [value_attr])
+                constant_nodes.append(const_node)
+                value_map[value] = result = const_node.outputs[0]
+                return result
+            raise ValueError(f"Value {value} not found in value_map.")
+        return value_map[value]
+
+    def copy_attr_value(attr: ir.Attr | ir.RefAttr) -> ir.Attr | ir.RefAttr:
+        if not isinstance(attr, ir.Attr):
+            # No need to support this currently, as rewriting inside a function is
+            # not used, as it has several challenges.
+            raise NotImplementedError("RefAttr not supported.")
+        if attr.type in {ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS}:
+            # No need to support this currently, as rewriting control-flow constructs
+            # is not used and has several challenges.
+            raise NotImplementedError("Graph attributes not supported.")
+        # Primitive attributes are immutable by design and can be shared.
+        return attr
+
+    def copy_node(node: ir.Node) -> ir.Node:
+        new_inputs = [copy_value(v) for v in node.inputs]
+        new_attributes = [copy_attr_value(v) for v in node.attributes.values()]
+        new_node = ir.Node(
+            node.domain,
+            node.op_type,
+            new_inputs,
+            new_attributes,
+            overload=node.overload,
+            num_outputs=len(node.outputs),
+            graph=None,
+            name=node.name,
+            doc_string=node.doc_string,  # type: ignore
+            metadata_props=node.metadata_props.copy(),
+        )
+        new_outputs = new_node.outputs
+        for i, output in enumerate(node.outputs):
+            value_map[output] = new_outputs[i]
+            if output.name is not None:
+                new_outputs[i].name = output.name
+        return new_node
+
+    function_nodes = [copy_node(node) for node in nodes]
+    function_outputs = [copy_value(v) for v in outputs]
+    return (function_inputs, constant_nodes + function_nodes, function_outputs)
+
+
+def _get_new_overload(model: ir.Model, domain: str, name: str) -> str:
+    """Get a new overload for the given domain and name.
+
+    Args:
+        model: The model to which the new overload will be added.
+        domain: The domain of the new overload.
+        name: The opname of the new overload.
+
+    Returns:
+        The new overload name.
+    """
+    existing_functions = model.functions
+    # Just a simple implementation for now
+    overload = 1
+    while True:
+        overload_name = str(overload)
+        if (domain, name, overload_name) not in existing_functions:
+            return overload_name
+        overload += 1
+

 class RewriteRuleSet:
     def __init__(self, rules: Sequence[RewriteRule], *, commute: bool = False) -> None:
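
The constant handling in copy_value above is the "constants in extracted functions" fix from the commit message: a matched subgraph may read a value carrying a const_value that is produced outside the match, and the extracted function body cannot reference outer values, so the constant is re-materialized as a Constant node prepended to the function. A rough sketch of that behavior (hypothetical values; assumes ir.tensor and a settable Value.const_value):

import numpy as np
import onnxscript.ir as ir
from onnxscript.rewriter.pattern import _copy_for_function  # private helper added above

x = ir.Value(name="x", type=ir.TensorType(ir.DataType.FLOAT))
w = ir.Value(name="w")
w.const_value = ir.tensor(np.ones((2, 2), dtype=np.float32))  # assumed setter

add = ir.Node("", "Add", [x, w], [], num_outputs=1)
inputs, nodes, outputs = _copy_for_function([x], [add], list(add.outputs))
# w is not a formal parameter, so it reappears as a leading Constant node:
assert nodes[0].op_type == "Constant"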
