Skip to content

Commit efbbaf4

Browse files
authored
also modify alibi.h
1 parent f5ce6ee commit efbbaf4

File tree

1 file changed

+13
-7
lines changed

1 file changed

+13
-7
lines changed

csrc/flash_attn/src/alibi.h

Lines changed: 13 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -37,14 +37,20 @@ struct Alibi {
3737
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
3838
if constexpr (Is_causal) { // Simpler, we add the same bias vector to all rows
3939
#pragma unroll
40-
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
41-
const int col_idx_base = col_idx_offset + nj * 8;
40+
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
41+
const int row_idx_base = row_idx_offset + mi * warp_row_stride;
4242
#pragma unroll
43-
for (int j = 0; j < size<1, 0>(tensor); ++j) {
44-
const int col_idx = col_idx_base + j;
43+
for (int i = 0; i < size<0, 0>(tensor); ++i) {
44+
const int row_idx = row_idx_base + i * 8;
45+
const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q);
4546
#pragma unroll
46-
for (int mi = 0; mi < size<0>(tensor); ++mi) {
47-
tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
47+
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
48+
const int col_idx_base = col_idx_offset + nj * 8;
49+
#pragma unroll
50+
for (int j = 0; j < size<1, 0>(tensor); ++j) {
51+
const int col_idx = col_idx_base + j;
52+
tensor(make_coord(i, mi), make_coord(j, nj)) += ((col_idx == (col_idx_limit_right - 1)) ? 0 : alibi_slope);
53+
}
4854
}
4955
}
5056
}
@@ -61,7 +67,7 @@ struct Alibi {
6167
#pragma unroll
6268
for (int j = 0; j < size<1, 0>(tensor); ++j) {
6369
const int col_idx = col_idx_base + j;
64-
tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
70+
tensor(make_coord(i, mi), make_coord(j, nj)) += (((row_idx + max_seqlen_k - max_seqlen_q - col_idx) == 0) ? 0 : alibi_slope);
6571
}
6672
}
6773
}

0 commit comments

Comments
 (0)