Skip to content

Commit 88e1cc9

Browse files
authored
Update mask.h
1 parent efbbaf4 commit 88e1cc9

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

csrc/flash_attn/src/mask.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ struct Mask {
181181
tensor(make_coord(i, mi), make_coord(j, nj)) += ((col_idx == (col_idx_limit_right - 1)) ? 0 : alibi_slope);
182182

183183
} else {
184-
tensor(make_coord(i, mi), make_coord(j, nj)) += (((row_idx + max_seqlen_k - max_seqlen_q - col_idx) == 0) ? 0 : alibi_slope);
184+
tensor(make_coord(i, mi), make_coord(j, nj)) += ((col_idx == row_idx) ? 0 : alibi_slope);
185185

186186
}
187187
}

0 commit comments

Comments
 (0)