diff --git a/src/common/transformations/include/transformations/op_conversions/einsum_decomposition.hpp b/src/common/transformations/include/transformations/op_conversions/einsum_decomposition.hpp
index af9d73d6c47296..b6dcbb5fa59941 100644
--- a/src/common/transformations/include/transformations/op_conversions/einsum_decomposition.hpp
+++ b/src/common/transformations/include/transformations/op_conversions/einsum_decomposition.hpp
@@ -25,5 +25,11 @@ class TRANSFORMATIONS_API EinsumDecomposition;
 class ov::pass::EinsumDecomposition : public ov::pass::MatcherPass {
 public:
     OPENVINO_MATCHER_PASS_RTTI("EinsumDecomposition");
-    EinsumDecomposition();
+    /// @brief Constructs the EinsumDecomposition matcher pass.
+    /// @param check_const If true, the matcher callback returns false (leaves the node untouched) for
+    ///        any Einsum node that has no Constant input; if false (default), this filter is not applied.
+    explicit EinsumDecomposition(bool check_const = false);
+
+private:
+    bool m_check_const;  // true => decompose only Einsum nodes with at least one Constant input
 };
diff --git a/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp b/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp
index 5ff07b5fa46ee5..bb18c82d0d0087 100644
--- a/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp
+++ b/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp
@@ -90,6 +90,7 @@
 #include "transformations/op_conversions/convert_scatter_elements_to_scatter.hpp"
 #include "transformations/op_conversions/convert_subtract.hpp"
 #include "transformations/op_conversions/convert_ti_to_sequences.hpp"
+#include "transformations/op_conversions/einsum_decomposition.hpp"
 #include "transformations/resolve_names_collisions.hpp"
 #include "transformations/smart_reshape/lstm_states_broadcast.hpp"
 #include "transformations/smart_reshape/matmul_sr.hpp"
@@ -164,6 +165,13 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr<ov::Model>
     REGISTER_PASS(manager, ConstantFolding)
     REGISTER_PASS(manager, Validate)
 
+    // EinsumDecomposition should be called after ConstantFolding
+    // for better performance and memory usage.
+    // ConstantFolding creates constant inputs to Einsum operations,
+    // which EinsumDecomposition can then decompose more efficiently with
+    // reduced memory consumption.
+    REGISTER_PASS(manager, EinsumDecomposition, true)
+
     // FusedFilteringBoxesBySize transformation has the complex pattern
     // which can be affected by further transformations. So we have to
     // execute it at the beginning of the pipeline. Also, this pass resolves
diff --git a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp
index dd5a36a25962e4..cd3aa72bd3750a 100644
--- a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp
+++ b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp
@@ -1291,7 +1291,7 @@ void fix_inputs_with_0d_ellipsis(ov::OutputVector& input_nodes,
 /// 8. Transpose dimensions to match the layout required by the output subscript.
 /// 9. Replace the original Einsum node with the last node from the decomposed sub-graph,
 ///    preserving the original node's name and runtime information.
-ov::pass::EinsumDecomposition::EinsumDecomposition() {
+ov::pass::EinsumDecomposition::EinsumDecomposition(bool check_const) : m_check_const(check_const) {
     MATCHER_SCOPE(EinsumDecomposition);
     auto einsum = ov::pass::pattern::wrap_type<ov::op::v7::Einsum>();
     matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
@@ -1300,6 +1300,28 @@ ov::pass::EinsumDecomposition::EinsumDecomposition() {
             return false;
         }
 
+        if (m_check_const) {
+            // This optimization targets Einsum operations in transformer models
+            // where at least one input is constant. After ConstantFolding,
+            // weight matrices become constants enabling efficient decomposition.
+            // Patterns (OPTIMIZED = decomposed here; NOT OPTIMIZED = skipped by this check):
+            // 1. Weight projections: einsum("abc,cd->abd", input, weight_matrix) - OPTIMIZED (constant weight)
+            // 2. Attention scores: einsum("aecd,abcd->acbe", key, query) - NOT OPTIMIZED (both variable)
+            // 3. Attention-value: einsum("acbe,aecd->abcd", attention_scores, value) - NOT OPTIMIZED (both variable)
+            // See: https://gist.github.com/Mohamed-Ashraf273/59eddcd120918cb0761ffa5020800d5d
+            bool has_const = false;
+            for (const auto& input : einsum_node->input_values()) {
+                auto node_ptr = input.get_node_shared_ptr();
+                auto constant_ptr = ov::as_type_ptr<ov::op::v0::Constant>(node_ptr);
+                if (constant_ptr) {
+                    has_const = true;
+                    break;
+                }
+            }
+            if (!has_const)
+                return false;
+        }
+
         // Parse the Einsum equation to get input and output subscripts
         auto equation = einsum_node->get_equation();
         std::vector<std::string> input_subscripts;