1 file changed: +4 -2 lines changed
Columns: original file line number, diff line number, diff line change
@@ -82,6 +82,10 @@ def online_softmax(
8282 )
8383
8484 row_max_cur = utils .warp_reduce (row_max_cur , cute .arch .fmax , width = 4 )
85+ # Update row_max before changing row_max_cur to safe value for -inf
86+ row_max_prev = row_max [r ]
87+ row_max [r ] = row_max_cur
88+
8589 if cutlass .const_expr (check_inf ):
8690 row_max_cur = 0.0 if row_max_cur == - Float32 .inf else row_max_cur
8791
@@ -92,7 +96,6 @@ def online_softmax(
9296 acc_S_row_sum = utils .fadd_reduce (acc_S_row_exp , init_val = None , arch = arch )
9397 row_scale [r ] = 1.0
9498 else :
95- row_max_prev = row_max [r ]
9699 row_max_cur_scaled = row_max_cur * scale_log2
97100 acc_S_row_exp = utils .exp2f (acc_S_row * scale_log2 - row_max_cur_scaled )
98101 # row_scale[r] = utils.exp2f(row_max_prev * self.scale_log2 - row_max_cur_scaled)
@@ -102,7 +105,6 @@ def online_softmax(
102105 acc_S_row_exp , init_val = row_sum [r ] * row_scale [r ], arch = arch
103106 )
104107
105- row_max [r ] = row_max_cur
106108 row_sum [r ] = acc_S_row_sum
107109 acc_S_mn [r , None ].store (acc_S_row_exp )
108110
You can’t perform that action at this time.
0 commit comments