Skip to content

Commit 13696f2

Browse files
authored
[Cute] update row_max before safe overwrite for online_softmax (#2174)
* update row_max before safe overwrite
* move up row_max_prev
1 parent 4894657 commit 13696f2

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

flash_attn/cute/softmax.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -82,6 +82,10 @@ def online_softmax(
82 82
)
83 83

84 84
row_max_cur = utils.warp_reduce(row_max_cur, cute.arch.fmax, width=4)
85+
# Update row_max before changing row_max_cur to safe value for -inf
86+
row_max_prev = row_max[r]
87+
row_max[r] = row_max_cur
88+
85 89
if cutlass.const_expr(check_inf):
86 90
row_max_cur = 0.0 if row_max_cur == -Float32.inf else row_max_cur
87 91

@@ -92,7 +96,6 @@ def online_softmax(
92 96
acc_S_row_sum = utils.fadd_reduce(acc_S_row_exp, init_val=None, arch=arch)
93 97
row_scale[r] = 1.0
94 98
else:
95-
row_max_prev = row_max[r]
96 99
row_max_cur_scaled = row_max_cur * scale_log2
97 100
acc_S_row_exp = utils.exp2f(acc_S_row * scale_log2 - row_max_cur_scaled)
98 101
# row_scale[r] = utils.exp2f(row_max_prev * self.scale_log2 - row_max_cur_scaled)
@@ -102,7 +105,6 @@ def online_softmax(
102 105
acc_S_row_exp, init_val=row_sum[r] * row_scale[r], arch=arch
103 106
)
104 107

105-
row_max[r] = row_max_cur
106 108
row_sum[r] = acc_S_row_sum
107 109
acc_S_mn[r, None].store(acc_S_row_exp)
108 110

0 commit comments

Comments
 (0)