@@ -29,7 +29,7 @@
 
 
 DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "FALSE") == "TRUE"
-TEST_BWD = True
+TEST_BWD_ONLY = False
 VERBOSE = True
 
 # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
@@ -117,7 +117,7 @@ def test_flash_attn_output(
     dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
     # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])
     dv_vals = [128] if d == 192 else ([d] if d != 128 else [64, d])
-    if dtype == torch.float8_e4m3fn or TEST_BWD:
+    if dtype == torch.float8_e4m3fn or TEST_BWD_ONLY:
         dv_vals = [d]
     # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0]
     attention_chunk_vals = [0]
@@ -236,7 +236,7 @@ def test_flash_attn_output(
     # pack_gqa_vals = [False, True, None]
    # SplitKV is not supported for hdim >= 192
     pack_gqa_vals = [False]
-    num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT and not TEST_BWD else [1]
+    num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT and not TEST_BWD_ONLY else [1]
     for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals):
         out, lse = flash_attn_func(
             q,
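For orientation, a minimal, self-contained sketch of how the two module-level gates in this diff interact. The sweep_sizes helper below is hypothetical and not part of the test file; its body copies the gating expressions from the hunks above. With TEST_BWD_ONLY set, the dv sweep collapses to [d] and the SplitKV variants (num_splits > 1) are skipped, matching the restrictions already applied for FP8 inputs, for head dims >= 192, and when FLASH_ATTENTION_DISABLE_SPLIT=TRUE is exported.

import os

# Module-level flags as they appear in the diff above.
DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "FALSE") == "TRUE"
TEST_BWD_ONLY = False

def sweep_sizes(d, dtype_is_fp8):
    # Hypothetical helper: reproduces the gating logic used in test_flash_attn_output.
    dv_vals = [128] if d == 192 else ([d] if d != 128 else [64, d])
    if dtype_is_fp8 or TEST_BWD_ONLY:
        dv_vals = [d]
    # SplitKV is not supported for hdim >= 192, can be disabled via the env var,
    # and is not exercised when only the backward pass is under test.
    num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT and not TEST_BWD_ONLY else [1]
    return dv_vals, num_splits_vals

print(sweep_sizes(128, dtype_is_fp8=False))  # ([64, 128], [1, 3]) with the defaults above and the env var unset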