Skip to content

Commit db00abf

Browse files
update rope
1 parent f59fc21 commit db00abf

File tree

1 file changed

+9
-10
lines changed

1 file changed

+9
-10
lines changed

src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp

Lines changed: 9 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -81,8 +81,8 @@ bool ov::pass::RoPEFusion::run_on_model(const std::shared_ptr<ov::Model>& model)
8181
static std::shared_ptr<ov::Node> gen_chatglm_const() {
8282
using namespace ov::pass::pattern;
8383

84-
auto pred = value_matches("-1, batch, head_cnt, ndims/2, 1") || value_matches("0, 0, 0, ndims/2, 1") ||
85-
value_matches("1, -1, head_cnt, ndims/2, 1");
84+
auto pred = value_matches("-1, batch, head_cnt, ndims/2, 1") || value_matches("1, -1, head_cnt, ndims/2, 1") ||
85+
value_matches("0, 0, 0, ndims/2, 1");
8686
return wrap_type<v0::Constant>(pred);
8787
}
8888

@@ -696,16 +696,16 @@ ov::pass::RoPEFusionChatGLM::RoPEFusionChatGLM(int split_output_id, const bool s
696696
std::shared_ptr<ov::Node> reshape0 = nullptr;
697697
if (support_2d_rope) {
698698
auto const_target_shape0 =
699-
pattern::wrap_type<v0::Constant>(pattern::value_matches("0, head_cnt_pos3, 0, ndims/2, 2"));
699+
pattern::wrap_type<v0::Constant>(pattern::value_matches("0, head_cnt, 0, ndims/2, 2"));
700700
reshape0 = pattern::wrap_type<v1::Reshape>({slice0 | var_split0->output(0), const_target_shape0},
701701
{{"special_zero", true}});
702702
} else {
703703
auto concat0 =
704704
pattern::wrap_type<v0::Concat>({seq_length, {-1}, {"head_cnt"}, {"ndims/2"}, {2}}, {{"axis", 0}});
705705
auto const_target_shape1 =
706-
pattern::wrap_type<v0::Constant>(pattern::value_matches("0, 0, head_cnt_pos4, ndims/2, 2"));
706+
pattern::wrap_type<v0::Constant>(pattern::value_matches("0, 0, head_cnt, ndims/2, 2"));
707707
auto const_target_shape2 =
708-
pattern::wrap_type<v0::Constant>(pattern::value_matches("seq_len, batch, head_cnt_pos5, ndims/2, 2"));
708+
pattern::wrap_type<v0::Constant>(pattern::value_matches("seq_len, batch, head_cnt, ndims/2, 2"));
709709
reshape0 = pattern::wrap_type<v1::Reshape>(
710710
{slice0 | var_split0->output(0), concat0 | const_target_shape1 | const_target_shape2});
711711
}
@@ -772,15 +772,14 @@ ov::pass::RoPEFusionChatGLM::RoPEFusionChatGLM(int split_output_id, const bool s
772772
if (support_2d_rope) {
773773
// [batch, head_cnt, length, half_rotary_dims, 2]
774774
const_target_shape6 =
775-
pattern::wrap_type<v0::Constant>(pattern::value_matches("batch, head_cnt_pos6, seq_len, ndims") ||
776-
pattern::value_matches("0, head_cnt_pos7, 0, ndims"));
775+
pattern::wrap_type<v0::Constant>(pattern::value_matches("batch, head_cnt, seq_len, ndims") ||
776+
pattern::value_matches("0, head_cnt, 0, ndims"));
777777
reshape2 = pattern::wrap_type<v1::Reshape>({concat2, concat3 | const_target_shape6}, {{"special_zero", true}});
778778
} else {
779779
// [length, batch, head_cnt, half_rotary_dims, 2]
780-
auto const_target_shape7 =
781-
pattern::wrap_type<v0::Constant>(pattern::value_matches("0, 0, head_cnt_pos9, ndims"));
780+
auto const_target_shape7 = pattern::wrap_type<v0::Constant>(pattern::value_matches("0, 0, head_cnt, ndims"));
782781
const_target_shape6 =
783-
pattern::wrap_type<v0::Constant>(pattern::value_matches("seq_len, batch, head_cnt_pos8, ndims"));
782+
pattern::wrap_type<v0::Constant>(pattern::value_matches("seq_len, batch, head_cnt, ndims"));
784783
reshape2 = pattern::wrap_type<v1::Reshape>({concat2, concat3 | const_target_shape6 | const_target_shape7},
785784
{{"special_zero", true}});
786785
}

0 commit comments

Comments (0)