Skip to content

Commit 8750107

Browse files
committed
revert to fwd test defaults
1 parent b508bda commit 8750107

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

tests/cute/test_flash_attn.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030

3131
DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "FALSE") == "TRUE"
32-
TEST_BWD = True
32+
TEST_BWD_ONLY = False
3333
VERBOSE = True
3434

3535
# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
@@ -117,7 +117,7 @@ def test_flash_attn_output(
117117
dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
118118
# dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])
119119
dv_vals = [128] if d == 192 else ([d] if d != 128 else [64, d])
120-
if dtype == torch.float8_e4m3fn or TEST_BWD:
120+
if dtype == torch.float8_e4m3fn or TEST_BWD_ONLY:
121121
dv_vals = [d]
122122
# attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0]
123123
attention_chunk_vals = [0]
@@ -236,7 +236,7 @@ def test_flash_attn_output(
236236
# pack_gqa_vals = [False, True, None]
237237
# SplitKV is not supported for hdim >= 192
238238
pack_gqa_vals = [False]
239-
num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT and not TEST_BWD else [1]
239+
num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT and not TEST_BWD_ONLY else [1]
240240
for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals):
241241
out, lse = flash_attn_func(
242242
q,

0 commit comments

Comments
 (0)