Skip to content

Commit f7e7e7f

Browse files
Evgenya Nugmanovaallnes
authored andcommitted
[SDPA] relax restriction on L from query and attention_mask being equal (openvinotoolkit#24745)
### Details: - *relax restriction on L from query and attention_mask being equal* ### Tickets: - *CVS-129000*
1 parent 8c669d0 commit f7e7e7f

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/core/shape_inference/include/scaled_dot_product_attention_shape_inference.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,9 @@ std::vector<TRShape> shape_infer(const ScaledDotProductAttention* op,
7373
const auto& attention_mask_rank = attention_mask.rank();
7474
if (attention_mask_rank.is_static() && attention_mask_rank != 0) {
7575
const auto& attention_mask_rank_len = attention_mask_rank.get_length();
76-
bool attention_mask_input_correctness = attention_mask_rank_len >= 2 &&
77-
DimType::merge(l_dim, l_dim, *(attention_mask.end() - 2)) &&
78-
DimType::merge(s_dim, s_dim, *(attention_mask.end() - 1));
76+
bool attention_mask_input_correctness =
77+
attention_mask_rank_len >= 2 && DimType::broadcast_merge(l_dim, l_dim, *(attention_mask.end() - 2)) &&
78+
DimType::broadcast_merge(s_dim, s_dim, *(attention_mask.end() - 1));
7979
if (attention_mask_rank_len >= 3) {
8080
attention_mask_input_correctness =
8181
attention_mask_input_correctness &&

0 commit comments

Comments
 (0)