diff --git a/src/core/shape_inference/include/scaled_dot_product_attention_shape_inference.hpp b/src/core/shape_inference/include/scaled_dot_product_attention_shape_inference.hpp
index 7b38bc8c10d141..b6b60cfad3aea2 100644
--- a/src/core/shape_inference/include/scaled_dot_product_attention_shape_inference.hpp
+++ b/src/core/shape_inference/include/scaled_dot_product_attention_shape_inference.hpp
@@ -73,9 +73,9 @@ std::vector<TRShape> shape_infer(const ScaledDotProductAttention* op,
         const auto& attention_mask_rank = attention_mask.rank();
         if (attention_mask_rank.is_static() && attention_mask_rank != 0) {
             const auto& attention_mask_rank_len = attention_mask_rank.get_length();
-            bool attention_mask_input_correctness = attention_mask_rank_len >= 2 &&
-                                                    DimType::merge(l_dim, l_dim, *(attention_mask.end() - 2)) &&
-                                                    DimType::merge(s_dim, s_dim, *(attention_mask.end() - 1));
+            bool attention_mask_input_correctness =
+                attention_mask_rank_len >= 2 && DimType::broadcast_merge(l_dim, l_dim, *(attention_mask.end() - 2)) &&
+                DimType::broadcast_merge(s_dim, s_dim, *(attention_mask.end() - 1));
             if (attention_mask_rank_len >= 3) {
                 attention_mask_input_correctness = attention_mask_input_correctness &&