Skip to content

Commit 9fe53f2

Browse files
gramalingambmehta001
authored and committed
Add Gelu Tanh fusion rule (microsoft#2132)
Add Gelu Tanh fusion rule
1 parent 2c80792 commit 9fe53f2

File tree

2 files changed

+94
-0
lines changed

2 files changed

+94
-0
lines changed
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
from __future__ import annotations
4+
5+
import math
6+
7+
from onnxscript import ir
8+
from onnxscript.rewriter import pattern
9+
10+
# Precomputed sqrt(2/pi): the coefficient inside the tanh-based GELU approximation.
_sqrt_two_over_pi = math.sqrt(2.0 / math.pi)
11+
12+
13+
class GeluTanhFusion(pattern.RewriteRuleClassBase):
    """Fuses the tanh-approximation GELU subgraph into a single contrib Gelu op.

    Matches the expanded form
        GELU(x) = 0.5 * x * (1 + Tanh[sqrt(2/pi) * (x + 0.044715 * x^3)])
    and replaces it with ``com.microsoft.Gelu``.
    """

    def pattern(self, op, x):
        # Build the expanded tanh-GELU expression node by node; the rewriter
        # matches this subgraph structurally against the model.
        x_cubed = op.Pow(x, 3)
        weighted_cube = op.Mul(0.044715, x_cubed)
        polynomial = op.Add(x, weighted_cube)
        tanh_input = op.Mul(_sqrt_two_over_pi, polynomial)
        tanh_output = op.Tanh(tanh_input)
        one_plus_tanh = op.Add(tanh_output, 1)
        gated = op.Mul(x, one_plus_tanh)
        return op.Mul(0.5, gated)

    def rewrite(self, op, x):
        # Replace the whole matched subgraph with one fused Gelu node from the
        # ONNX Runtime contrib-op domain.
        return op.Gelu(x, _domain="com.microsoft")
29+
30+
31+
# Instantiate the rewrite rule once at module load.
_rule = GeluTanhFusion.rule()

# Public rule set containing all GELU fusion rules defined in this module.
gelu_rules = pattern.RewriteRuleSet([_rule])
34+
35+
36+
def fuse_gelu(model: ir.Model) -> None:
    """Rewrite, in place, every tanh-approximation GELU subgraph in *model*
    into a single ``com.microsoft.Gelu`` node."""
    gelu_rules.apply_to_model(model)
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
import math
5+
import unittest
6+
7+
import numpy as np
8+
9+
import onnxscript.ir as ir
10+
import onnxscript.rewriter.ort_fusions._test_utils as test_utils
11+
from onnxscript import FLOAT, script
12+
from onnxscript import opset18 as op
13+
from onnxscript.optimizer import optimize, remove_unused_nodes
14+
from onnxscript.rewriter.ort_fusions.gelu import fuse_gelu
15+
16+
17+
class GeluFusionTest(unittest.TestCase):
    def test_gelu_fusion(self):
        """The expanded tanh-GELU subgraph should collapse into one Gelu node
        without changing the numeric result."""
        sqrt_two_over_pi = math.sqrt(2.0 / math.pi)

        @script()
        def gelu_model(x):
            # GELU(x) = 0.5 * x * (1 + Tanh[sqrt(2/pi) * (x + 0.044715 * x^3)])
            x_cubed = op.Pow(x, 3)
            weighted_cube = op.Mul(0.044715, x_cubed)
            polynomial = op.Add(x, weighted_cube)
            tanh_input = op.Mul(sqrt_two_over_pi, polynomial)
            tanh_output = op.Tanh(tanh_input)
            one_plus_tanh = op.Add(tanh_output, 1)
            gated = op.Mul(x, one_plus_tanh)
            result = op.Mul(0.5, gated)
            return result

        model_proto = gelu_model.to_model_proto(
            input_types=[FLOAT[10]], output_types=[FLOAT[10]]
        )
        model = ir.serde.deserialize_model(model_proto)

        # Eliminate redundant CastLike ops inserted by the script converter:
        optimize(model)

        # NOTE: "input" was renamed to "feed" to avoid shadowing the builtin.
        feed = {"x": np.random.randn(10).astype(np.float32)}
        baseline_output = test_utils.ort_run("Original", model, feed)

        fuse_gelu(model)
        remove_unused_nodes(model)

        # After fusion the graph should consist of exactly one Gelu node.
        self.assertEqual(len(model.graph), 1)
        self.assertEqual(model.graph.node(0).op_type, "Gelu")

        fused_output = test_utils.ort_run("Optimized", model, feed)
        test_utils.assert_allclose(baseline_output, fused_output)
54+
55+
56+
# Allow running this test module directly (e.g. `python gelu_test.py`).
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)