Skip to content

Commit f7dbd17

Browse files
yeonbokCuriousPanCake
authored and committed
[GPU] Support GeLU Tanh for Phi-2 (openvinotoolkit#27213)
### Details: - Previously, GeLU Tanh fusion was supported only for the pattern x * (0.5 * (1 + tanh)). - Support the pattern (x * 0.5) * (1 + tanh) too. ### Tickets: - 155576
1 parent cdf07f5 commit f7dbd17

File tree

2 files changed

+59
-15
lines changed

2 files changed

+59
-15
lines changed

src/common/transformations/src/transformations/common_optimizations/gelu_fusion.cpp

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "openvino/op/parameter.hpp"
2323
#include "openvino/op/power.hpp"
2424
#include "openvino/op/tanh.hpp"
25+
#include "openvino/pass/pattern/op/or.hpp"
2526
#include "openvino/pass/pattern/op/wrap_type.hpp"
2627
#include "transformations/utils/utils.hpp"
2728

@@ -280,9 +281,16 @@ ov::pass::GeluFusionWithTanh::GeluFusionWithTanh() {
280281
auto add_1 = ov::pass::pattern::wrap_type<ov::op::v1::Add>({tanh, add_1_constant});
281282

282283
auto mul_2_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>();
283-
auto mul_2 = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({add_1, mul_2_constant});
284284

285-
auto mul_3 = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({input, mul_2});
285+
// x * (0.5 * (1 + tanh))
286+
auto mul_2_1 = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({add_1, mul_2_constant});
287+
auto mul_3_1 = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({input, mul_2_1});
288+
289+
// (x * 0.5) * (1 + tanh)
290+
auto mul_2_2 = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({input, mul_2_constant});
291+
auto mul_3_2 = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({add_1, mul_2_2});
292+
293+
auto mul_3 = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{mul_3_1, mul_3_2});
286294

287295
ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
288296
auto& pattern_to_output = m.get_pattern_value_map();
@@ -298,7 +306,6 @@ ov::pass::GeluFusionWithTanh::GeluFusionWithTanh() {
298306
ov::as_type_ptr<ov::op::v0::Constant>(pattern_to_output.at(mul_2_constant).get_node_shared_ptr());
299307
auto add_1_constant_value =
300308
ov::as_type_ptr<ov::op::v0::Constant>(pattern_to_output.at(add_1_constant).get_node_shared_ptr());
301-
302309
if (!pow_constant_value || !add_1_constant_value || !mul_0_constant_value || !mul_1_constant_value ||
303310
!mul_2_constant_value) {
304311
return false;
@@ -318,18 +325,17 @@ ov::pass::GeluFusionWithTanh::GeluFusionWithTanh() {
318325
auto gelu = std::make_shared<ov::op::v7::Gelu>(x_output, op::GeluApproximationMode::TANH);
319326

320327
gelu->set_friendly_name(m.get_match_root()->get_friendly_name());
321-
ov::copy_runtime_info(
322-
{
323-
pattern_to_output.at(pow).get_node_shared_ptr(),
324-
pattern_to_output.at(mul_0).get_node_shared_ptr(),
325-
pattern_to_output.at(mul_1).get_node_shared_ptr(),
326-
pattern_to_output.at(mul_2).get_node_shared_ptr(),
327-
pattern_to_output.at(mul_3).get_node_shared_ptr(),
328-
pattern_to_output.at(tanh).get_node_shared_ptr(),
329-
pattern_to_output.at(add_0).get_node_shared_ptr(),
330-
pattern_to_output.at(add_1).get_node_shared_ptr(),
331-
},
332-
gelu);
328+
329+
std::vector<std::shared_ptr<ov::Node>> pattern_nodes =
330+
{pow, mul_0, mul_1, tanh, add_0, add_1, mul_2_1, mul_2_2, mul_3_1, mul_3_2};
331+
std::vector<std::shared_ptr<ov::Node>> cp_rt_info_nodes;
332+
for (const auto& pattern_node : pattern_nodes) {
333+
if (pattern_to_output.count(pattern_node)) {
334+
cp_rt_info_nodes.push_back(pattern_to_output.at(pattern_node).get_node_shared_ptr());
335+
}
336+
}
337+
ov::copy_runtime_info(cp_rt_info_nodes, gelu);
338+
333339
ov::replace_node(m.get_match_root(), gelu);
334340
return true;
335341
};

src/common/transformations/tests/common_optimizations/gelu_fusion.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,44 @@ TEST_F(TransformationTestsF, GeluFusionTanhWithTanh_epsilon_pow_value) {
388388
}
389389
}
390390

391+
TEST_F(TransformationTestsF, GeluFusionTanhWithTanh_epsilon_pow_value_2) {
392+
{
393+
auto input = std::make_shared<ov::op::v0::Parameter>(element::f32, Shape{2, 2});
394+
auto pow_constant =
395+
std::make_shared<ov::op::v0::Constant>(element::f32, Shape{1}, std::vector<float>{3.0f + 1.0e-8f});
396+
auto pow = std::make_shared<ov::op::v1::Power>(input, pow_constant);
397+
auto mul_0_constant =
398+
std::make_shared<ov::op::v0::Constant>(element::f32, Shape{1}, std::vector<float>{0.044715f});
399+
auto mul_0 = std::make_shared<ov::op::v1::Multiply>(pow, mul_0_constant);
400+
auto add_0 = std::make_shared<ov::op::v1::Add>(input, mul_0);
401+
402+
auto mul_1_constant =
403+
std::make_shared<ov::op::v0::Constant>(element::f32,
404+
Shape{1},
405+
std::vector<float>{static_cast<float>(std::sqrt(2.0 / M_PI))});
406+
auto mul_1 = std::make_shared<ov::op::v1::Multiply>(add_0, mul_1_constant);
407+
408+
auto tanh = std::make_shared<ov::op::v0::Tanh>(mul_1);
409+
410+
auto add_1_constant = std::make_shared<ov::op::v0::Constant>(element::f32, Shape{1}, std::vector<float>{1.0f});
411+
auto add_1 = std::make_shared<ov::op::v1::Add>(tanh, add_1_constant);
412+
413+
auto mul_2_constant = std::make_shared<ov::op::v0::Constant>(element::f32, Shape{1}, std::vector<float>{0.5f});
414+
auto mul_2 = std::make_shared<ov::op::v1::Multiply>(input, mul_2_constant);
415+
416+
auto mul_3 = std::make_shared<ov::op::v1::Multiply>(add_1, mul_2);
417+
418+
model = std::make_shared<Model>(NodeVector{mul_3}, ParameterVector{input});
419+
manager.register_pass<ov::pass::GeluFusionWithTanh>();
420+
}
421+
422+
{
423+
auto data = std::make_shared<ov::op::v0::Parameter>(element::f32, Shape{2, 2});
424+
auto gelu = std::make_shared<ov::op::v7::Gelu>(data, op::GeluApproximationMode::TANH);
425+
model_ref = std::make_shared<Model>(NodeVector{gelu}, ParameterVector{data});
426+
}
427+
}
428+
391429
TEST_F(TransformationTestsF, GeluFusionTanhWithTanh_wrong_pow_value) {
392430
{
393431
auto input = std::make_shared<ov::op::v0::Parameter>(element::f32, Shape{2, 2});

0 commit comments

Comments
 (0)