@@ -81,8 +81,8 @@ bool ov::pass::RoPEFusion::run_on_model(const std::shared_ptr<ov::Model>& model)
8181static std::shared_ptr<ov::Node> gen_chatglm_const () {
8282 using namespace ov ::pass::pattern;
8383
84- auto pred = value_matches (" -1, batch, head_cnt, ndims/2, 1" ) || value_matches (" 0, 0, 0 , ndims/2, 1" ) ||
85- value_matches (" 1, -1, head_cnt , ndims/2, 1" );
84+ auto pred = value_matches (" -1, batch, head_cnt, ndims/2, 1" ) || value_matches (" 1, -1, head_cnt , ndims/2, 1" ) ||
85+ value_matches (" 0, 0, 0 , ndims/2, 1" );
8686 return wrap_type<v0::Constant>(pred);
8787}
8888
@@ -696,16 +696,16 @@ ov::pass::RoPEFusionChatGLM::RoPEFusionChatGLM(int split_output_id, const bool s
696696 std::shared_ptr<ov::Node> reshape0 = nullptr ;
697697 if (support_2d_rope) {
698698 auto const_target_shape0 =
699- pattern::wrap_type<v0::Constant>(pattern::value_matches (" 0, head_cnt_pos3 , 0, ndims/2, 2" ));
699+ pattern::wrap_type<v0::Constant>(pattern::value_matches (" 0, head_cnt , 0, ndims/2, 2" ));
700700 reshape0 = pattern::wrap_type<v1::Reshape>({slice0 | var_split0->output (0 ), const_target_shape0},
701701 {{" special_zero" , true }});
702702 } else {
703703 auto concat0 =
704704 pattern::wrap_type<v0::Concat>({seq_length, {-1 }, {" head_cnt" }, {" ndims/2" }, {2 }}, {{" axis" , 0 }});
705705 auto const_target_shape1 =
706- pattern::wrap_type<v0::Constant>(pattern::value_matches (" 0, 0, head_cnt_pos4 , ndims/2, 2" ));
706+ pattern::wrap_type<v0::Constant>(pattern::value_matches (" 0, 0, head_cnt , ndims/2, 2" ));
707707 auto const_target_shape2 =
708- pattern::wrap_type<v0::Constant>(pattern::value_matches (" seq_len, batch, head_cnt_pos5 , ndims/2, 2" ));
708+ pattern::wrap_type<v0::Constant>(pattern::value_matches (" seq_len, batch, head_cnt , ndims/2, 2" ));
709709 reshape0 = pattern::wrap_type<v1::Reshape>(
710710 {slice0 | var_split0->output (0 ), concat0 | const_target_shape1 | const_target_shape2});
711711 }
@@ -772,15 +772,14 @@ ov::pass::RoPEFusionChatGLM::RoPEFusionChatGLM(int split_output_id, const bool s
772772 if (support_2d_rope) {
773773 // [batch, head_cnt, length, half_rotary_dims, 2]
774774 const_target_shape6 =
775- pattern::wrap_type<v0::Constant>(pattern::value_matches (" batch, head_cnt_pos6 , seq_len, ndims" ) ||
776- pattern::value_matches (" 0, head_cnt_pos7 , 0, ndims" ));
775+ pattern::wrap_type<v0::Constant>(pattern::value_matches (" batch, head_cnt , seq_len, ndims" ) ||
776+ pattern::value_matches (" 0, head_cnt , 0, ndims" ));
777777 reshape2 = pattern::wrap_type<v1::Reshape>({concat2, concat3 | const_target_shape6}, {{" special_zero" , true }});
778778 } else {
779779 // [length, batch, head_cnt, half_rotary_dims, 2]
780- auto const_target_shape7 =
781- pattern::wrap_type<v0::Constant>(pattern::value_matches (" 0, 0, head_cnt_pos9, ndims" ));
780+ auto const_target_shape7 = pattern::wrap_type<v0::Constant>(pattern::value_matches (" 0, 0, head_cnt, ndims" ));
782781 const_target_shape6 =
783- pattern::wrap_type<v0::Constant>(pattern::value_matches (" seq_len, batch, head_cnt_pos8 , ndims" ));
782+ pattern::wrap_type<v0::Constant>(pattern::value_matches (" seq_len, batch, head_cnt , ndims" ));
784783 reshape2 = pattern::wrap_type<v1::Reshape>({concat2, concat3 | const_target_shape6 | const_target_shape7},
785784 {{" special_zero" , true }});
786785 }
0 commit comments