@@ -37,14 +37,20 @@ struct Alibi {
37
37
const int col_idx_offset = col_idx_offset_ + (lane_id % 4 ) * 2 ;
38
38
if constexpr (Is_causal) { // Simpler, we add the same bias vector to all rows
39
39
#pragma unroll
40
- for (int nj = 0 ; nj < size<1 , 1 >(tensor); ++nj ) {
41
- const int col_idx_base = col_idx_offset + nj * 8 ;
40
+ for (int mi = 0 ; mi < size<0 , 1 >(tensor); ++mi ) {
41
+ const int row_idx_base = row_idx_offset + mi * warp_row_stride ;
42
42
#pragma unroll
43
- for (int j = 0 ; j < size<1 , 0 >(tensor); ++j) {
44
- const int col_idx = col_idx_base + j;
43
+ for (int i = 0 ; i < size<0 , 0 >(tensor); ++i) {
44
+ const int row_idx = row_idx_base + i * 8 ;
45
+ const int col_idx_limit_right = std::min (max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q);
45
46
#pragma unroll
46
- for (int mi = 0 ; mi < size<0 >(tensor); ++mi) {
47
- tensor (mi, make_coord (j, nj)) += alibi_slope * col_idx;
47
+ for (int nj = 0 ; nj < size<1 , 1 >(tensor); ++nj) {
48
+ const int col_idx_base = col_idx_offset + nj * 8 ;
49
+ #pragma unroll
50
+ for (int j = 0 ; j < size<1 , 0 >(tensor); ++j) {
51
+ const int col_idx = col_idx_base + j;
52
+ tensor (make_coord (i, mi), make_coord (j, nj)) += ((col_idx == (col_idx_limit_right - 1 )) ? 0 : alibi_slope);
53
+ }
48
54
}
49
55
}
50
56
}
@@ -61,7 +67,7 @@ struct Alibi {
61
67
#pragma unroll
62
68
for (int j = 0 ; j < size<1 , 0 >(tensor); ++j) {
63
69
const int col_idx = col_idx_base + j;
64
- tensor (make_coord (i, mi), make_coord (j, nj)) -= alibi_slope * abs ( row_idx + max_seqlen_k - max_seqlen_q - col_idx);
70
+ tensor (make_coord (i, mi), make_coord (j, nj)) += ((( row_idx + max_seqlen_k - max_seqlen_q - col_idx) == 0 ) ? 0 : alibi_slope );
65
71
}
66
72
}
67
73
}
0 commit comments