diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index c66ca7553a..02c618211e 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -234,7 +234,13 @@ def _setup_attributes(self): - Configures pipeline stages for softmax, correction, and epilogue operations """ - self.kv_stage = 4 if self.q_dtype.width == 8 or self.q_stage == 1 else 3 + self.kv_stage = ( + 4 + if (self.q_dtype.width == 8 or self.q_stage == 1) + and self.head_dim_padded <= 128 + and self.head_dim_v_padded <= 128 + else 3 + ) self.acc_stage = 1 # For hdim 192,128, we don't have enough smem to store all 3 stages of KV: # 128 x 192 x 2 bytes x 3 stages = 144KB, and we need 96KB for Q.