2222#include " openvino/op/parameter.hpp"
2323#include " openvino/op/power.hpp"
2424#include " openvino/op/tanh.hpp"
25+ #include " openvino/pass/pattern/op/or.hpp"
2526#include " openvino/pass/pattern/op/wrap_type.hpp"
2627#include " transformations/utils/utils.hpp"
2728
@@ -280,9 +281,16 @@ ov::pass::GeluFusionWithTanh::GeluFusionWithTanh() {
280281 auto add_1 = ov::pass::pattern::wrap_type<ov::op::v1::Add>({tanh, add_1_constant});
281282
282283 auto mul_2_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>();
283- auto mul_2 = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({add_1, mul_2_constant});
284284
285- auto mul_3 = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({input, mul_2});
285+ // x * (0.5 * (1 + tanh))
286+ auto mul_2_1 = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({add_1, mul_2_constant});
287+ auto mul_3_1 = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({input, mul_2_1});
288+
289+ // (x * 0.5) * (1 + tanh)
290+ auto mul_2_2 = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({input, mul_2_constant});
291+ auto mul_3_2 = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({add_1, mul_2_2});
292+
293+ auto mul_3 = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{mul_3_1, mul_3_2});
286294
287295 ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
288296 auto & pattern_to_output = m.get_pattern_value_map ();
@@ -298,7 +306,6 @@ ov::pass::GeluFusionWithTanh::GeluFusionWithTanh() {
298306 ov::as_type_ptr<ov::op::v0::Constant>(pattern_to_output.at (mul_2_constant).get_node_shared_ptr ());
299307 auto add_1_constant_value =
300308 ov::as_type_ptr<ov::op::v0::Constant>(pattern_to_output.at (add_1_constant).get_node_shared_ptr ());
301-
302309 if (!pow_constant_value || !add_1_constant_value || !mul_0_constant_value || !mul_1_constant_value ||
303310 !mul_2_constant_value) {
304311 return false ;
@@ -318,18 +325,17 @@ ov::pass::GeluFusionWithTanh::GeluFusionWithTanh() {
318325 auto gelu = std::make_shared<ov::op::v7::Gelu>(x_output, op::GeluApproximationMode::TANH);
319326
320327 gelu->set_friendly_name (m.get_match_root ()->get_friendly_name ());
321- ov::copy_runtime_info (
322- {
323- pattern_to_output.at (pow).get_node_shared_ptr (),
324- pattern_to_output.at (mul_0).get_node_shared_ptr (),
325- pattern_to_output.at (mul_1).get_node_shared_ptr (),
326- pattern_to_output.at (mul_2).get_node_shared_ptr (),
327- pattern_to_output.at (mul_3).get_node_shared_ptr (),
328- pattern_to_output.at (tanh).get_node_shared_ptr (),
329- pattern_to_output.at (add_0).get_node_shared_ptr (),
330- pattern_to_output.at (add_1).get_node_shared_ptr (),
331- },
332- gelu);
328+
329+ std::vector<std::shared_ptr<ov::Node>> pattern_nodes =
330+ {pow, mul_0, mul_1, tanh, add_0, add_1, mul_2_1, mul_2_2, mul_3_1, mul_3_2};
331+ std::vector<std::shared_ptr<ov::Node>> cp_rt_info_nodes;
332+ for (const auto & pattern_node : pattern_nodes) {
333+ if (pattern_to_output.count (pattern_node)) {
334+ cp_rt_info_nodes.push_back (pattern_to_output.at (pattern_node).get_node_shared_ptr ());
335+ }
336+ }
337+ ov::copy_runtime_info (cp_rt_info_nodes, gelu);
338+
333339 ov::replace_node (m.get_match_root (), gelu);
334340 return true ;
335341 };
0 commit comments