Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@
#include "transformations/op_conversions/softmax_decomposition.hpp"
#include "transformations/op_conversions/softsign_decomposition.hpp"
#include "transformations/op_conversions/unique_decomposition.hpp"
#include "transformations/symbolic_transformations/dereshape_matmul.hpp"
#include "transformations/symbolic_transformations/symbolic_optimizations.hpp"

bool ov::pass::CommonOptimizations::run_on_model(const std::shared_ptr<ov::Model>& f) {
Expand Down Expand Up @@ -188,6 +189,9 @@ bool ov::pass::CommonOptimizations::run_on_model(const std::shared_ptr<ov::Model
ADD_MATCHER(decomp, UniqueDecomposition)
decomp->set_name("ov::pass::CommonDecompositions");

REGISTER_PASS(manager, NopElimination, true)
REGISTER_PASS(manager, DeReshapeMatMul)
REGISTER_PASS(manager, DeReshapeFullyConnected)
// CF is required after all decompositions
REGISTER_PASS(manager, ConstantFolding)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ static std::shared_ptr<ov::Node> gen_chatglm_const() {
using namespace pattern;

auto pred = value_matches("-1, head_cnt, 1, ndims/2, 1") || value_matches("1, -1, head_cnt, ndims/2, 1") ||
value_matches("0, 0, 0, ndims/2, 1");
value_matches("0, 0, 0, ndims/2, 1") || value_matches("-1, batch, head_cnt, ndims/2, 1");
return wrap_type<v0::Constant>(pred);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,32 @@
#include "openvino/op/ops.hpp"
#include "openvino/pass/constant_folding.hpp"
#include "openvino/pass/manager.hpp"
#include "transformations/common_optimizations/common_optimizations.hpp"
#include "transformations/init_node_info.hpp"
#include "transformations/rt_info/fused_names_attribute.hpp"
#include "transformations/utils/utils.hpp"

using namespace ov;
using namespace std;

TEST(nop_elimination, shared_const_einsum_after_common_optimizations) {
auto const_data = op::v0::Constant::create(element::f32, Shape{2, 2}, {1, 2, 3, 4});

auto einsum1 = std::make_shared<op::v7::Einsum>(OutputVector{const_data}, "ii->i");
auto einsum2 = std::make_shared<op::v7::Einsum>(OutputVector{const_data}, "ii->i");

auto model = std::make_shared<ov::Model>(OutputVector{einsum1, einsum2}, ov::ParameterVector{});

ov::pass::Manager pass_manager;
pass_manager.register_pass<ov::pass::CommonOptimizations>();
pass_manager.run_passes(model);

auto einsum1_const = einsum1->input_value(0).get_node_shared_ptr();
auto einsum2_const = einsum2->input_value(0).get_node_shared_ptr();

ASSERT_EQ(einsum1_const, einsum2_const);
}

TEST(nop_elimination, eliminate_convert) {
std::shared_ptr<ov::Model> f;
{
Expand Down
Loading