Commit a967cfe ("benchmarks")
Parent: 66ae0bc

4 files changed (+141, -71 lines)

hopper/benchmark_attn.py (12 additions, 9 deletions)
@@ -13,7 +13,7 @@
     import cudnn
 except ImportError:
     cudnn = None
-# cudnn = None
+cudnn = None
 
 Timing = NamedTuple('timing', [('mean', float)])
 
@@ -24,8 +24,8 @@
 from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler
 from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func
 from flash_attn_interface import flash_attn_func as flash_attn_func_v3
-# from flash_attn_interface import flash_attn_with_kvcache as flash_attn_func_v3
-from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3
+from flash_attn_interface import flash_attn_with_kvcache as flash_attn_func_v3
+# from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3
 
 from triton.testing import do_bench
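
Note on the swap above: `flash_attn_func_v3` now points at the KV-cache entry point, so the benchmark times the inference path rather than the plain forward. A minimal usage sketch, assuming the hopper interface keeps the flash_attn 2.x-style argument names (`cache_seqlens` and `causal` are assumptions here, not taken from this diff):

```python
# Hedged sketch: single-token decode through flash_attn_with_kvcache.
# Keyword names follow the flash_attn 2.x-style API; details may differ.
import torch
from flash_attn_interface import flash_attn_with_kvcache

batch, seqlen_k, nheads, nkvheads, headdim = 1, 16384, 16, 1, 128
q = torch.randn(batch, 1, nheads, headdim, device='cuda', dtype=torch.bfloat16)
k_cache = torch.randn(batch, seqlen_k, nkvheads, headdim, device='cuda', dtype=torch.bfloat16)
v_cache = torch.randn_like(k_cache)
# Number of valid cache entries per batch element.
cache_seqlens = torch.full((batch,), seqlen_k, dtype=torch.int32, device='cuda')

out = flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=True)
print(out.shape)  # (1, 1, 16, 128)
```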

@@ -226,21 +226,22 @@ def run(*args, **kwargs):
 softcap = 0.0
 V_colmajor = False
 deterministic = False
-batch_size = 2
+batch_size = 1
 # seqlen = 2048
-seqlen = 8192
+# seqlen = 8192
+seqlen = 2048 * 8
 # seqlen = 4096
 # seqlen = 2047
-dim = 2048
-# headdim = 128
+dim = 128 * 16
+headdim = 128
 # headdim = 64
-headdim = 256
+# headdim = 256
 # for headdim in [64, 128, 256]:
 # bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]
 # bs_seqlen_vals = [(16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]
 # bs_seqlen_vals = [(32, 512), (16, 1024)]
 # bs_seqlen_vals = [(2, 64 * 132)]
-bs_seqlen_vals = [(2, 8192)]
+bs_seqlen_vals = [(1, 8192 * 2)]
 # bs_seqlen_vals = [(1, 16 * 1024)]
 time_f = {}
 time_b = {}
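
As a sanity check on the rewritten constants: `dim = 128 * 16` with `headdim = 128` gives 16 heads, and `seqlen = 2048 * 8` equals the `8192 * 2` in `bs_seqlen_vals`, so the scalar config and the list agree. A quick illustrative check (`nheads = dim // headdim` follows the convention used elsewhere in these benchmarks):

```python
# Illustrative check of the rewritten config (not part of the benchmark).
dim, headdim = 128 * 16, 128
batch_size, seqlen = 1, 2048 * 8
nheads = dim // headdim
assert (batch_size, seqlen) == (1, 8192 * 2)
print(nheads, seqlen)  # 16 heads, 16384 tokens
```
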
@@ -272,8 +273,10 @@ def run(*args, **kwargs):
     window_size = (-1, -1)
     # window_size = (seqlen // 2 - 1, 0)
     pack_gqa = None
+    # pack_gqa = True
     # seqlen_q = 64
     seqlen_q = seqlen
+    # seqlen_q = 1
     leftpad_k = None
     # leftpad_k = torch.full((batch_size,), 0, device=device, dtype=torch.int32)
     q = torch.randn(batch_size, seqlen_q, nheads, headdim, device=device, dtype=dtype_gen, requires_grad=True)

hopper/benchmark_flash_attention_fp8.py (116 additions, 49 deletions)
@@ -29,6 +29,7 @@
     import cudnn
 except ImportError:
     cudnn = None
+cudnn = None
 
 
 def convert_to_cudnn_type(torch_type):
@@ -198,14 +199,22 @@ def attention_pytorch(qkv, dropout_p=0.0, causal=True):
     output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
     return output.to(dtype=qkv.dtype)
 
-def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
+def flops(batch, q_seqlen, seqlen, headdim, nheads, causal, mode="fwd"):
     assert mode in ["fwd", "bwd", "fwd_bwd"]
-    f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
+    f = 4 * batch * q_seqlen * seqlen * nheads * headdim // (2 if causal else 1)
     return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)
 
 def efficiency(flop, time):
     return (flop / time / 10**12) if not math.isnan(time) else 0.0
 
+def data_size(batch, q_seqlen, seqlen, headdim, nheads, nkvheads, nbytes, mode="fwd"):
+    assert mode in ["fwd"]
+    d_size = batch * nbytes * headdim * (q_seqlen * nheads * 2 + seqlen * nkvheads * 2)
+    return d_size
+
+def mem_bw(nbytes, time):
+    return (nbytes / time / 1024 / 1024 / 1024 / 1024) if not math.isnan(time) else 0.0
+
 def time_fwd(func, *args, **kwargs):
     time.sleep(1)  # Sleep to avoid residual power throttling from the previous benchmark
     time_f = benchmark_forward(func, *args, **kwargs)
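
The reworked `flops` counter lets the query length differ from the key/value length (decode runs with `q_seqlen = 1`), and the new `data_size`/`mem_bw` pair reports achieved bandwidth next to TFLOPs. A standalone restatement of the two formulas with the benchmark's default prefill shape; the numbers are purely illustrative:

```python
# Standalone restatement of the counters above (illustrative numbers only).
# Forward attention does two matmuls (QK^T and PV), each costing
# 2 * q_seqlen * seqlen * headdim multiply-adds per head; causal masking
# halves the work. data_size counts reading Q, K, V and writing O once each.

def flops(batch, q_seqlen, seqlen, headdim, nheads, causal):
    return 4 * batch * q_seqlen * seqlen * nheads * headdim // (2 if causal else 1)

def data_size(batch, q_seqlen, seqlen, headdim, nheads, nkvheads, nbytes):
    return batch * nbytes * headdim * (q_seqlen * nheads * 2 + seqlen * nkvheads * 2)

# Default prefill config: batch=1, seqlen=2048*8, 16 query heads, 1 KV head,
# headdim=128, fp8 (1 byte per element).
f = flops(1, 16384, 16384, 128, 16, causal=True)      # ~1.10e12 FLOPs
d = data_size(1, 16384, 16384, 128, 16, 1, nbytes=1)  # ~68 MiB moved
print(f"{f / 1e12:.2f} TFLOP, {d / 2**20:.0f} MiB")
```
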
@@ -216,30 +225,50 @@ def time_fwd(func, *args, **kwargs):
 
 repeats = 30
 device = 'cuda'
-# dtype = torch.float16
+dtype = torch.float16
 dtype = torch.float8_e4m3fn
+is_gqa = True
 
+# For prefill
+q_seqlen_val = None
 # bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4224), (2, 8448), (1, 8448 * 2)]
-bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 8192 * 2)]
+# bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 8192 * 2)]
+bs_seqlen_vals = [(1, 2048 * 8)]
+# bs_seqlen_vals = [(32, 8192), (32, 2048)]
+
+# For decode
+# q_seqlen_val = 1
+# q_seqlen_val = 4
+# bs_seqlen_vals = [(1, 128)]
 # bs_seqlen_vals = [(4, 4096), (2, 8192), (1, 8192 * 2)]
 # bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048)]
-causal_vals = [False, True]
-headdim_vals = [64, 128, 256]
-dim = 2048
+# bs_seqlen_vals = [(32, 8192*4), (32, 8192*2), (32, 8192), (32, 4096), (32, 2048), (32, 1024), (64, 8192*2), (128, 8192), (128, 4096), (128, 2048), (128, 1024)]
+
+
+# causal_vals = [False, True]
+causal_vals = [True]
+# headdim_vals = [64, 128, 256]
+headdim_vals = [128]
+dim = 128 * 16
 # dim = 256
 dropout_p = 0.0
 
-methods = (["Pytorch", "Flash3"]
-           + (["cuDNN"] if cudnn is not None else [])
-           # + (["Triton"] if attention_triton is not None else [])
-           # + (["xformers.c"] if xops is not None else [])
-           # + (["xformers.f"] if xops is not None else [])
-           )
+# scaling_recipe = 1
+scaling_recipe = 0
+
+methods = (["Flash3"])
+# methods = (["Pytorch", "Flash3"]
+#            + (["cuDNN"] if cudnn is not None else [])
+#            # + (["Triton"] if attention_triton is not None else [])
+#            # + (["xformers.c"] if xops is not None else [])
+#            # + (["xformers.f"] if xops is not None else [])
+#            )
 
 time_f = {}
 time_b = {}
 time_f_b = {}
 speed_f = {}
+mem_bw_f = {}
 speed_b = {}
 speed_f_b = {}
 for causal in causal_vals:
@@ -248,55 +277,89 @@ def time_fwd(func, *args, **kwargs):
             torch.cuda.empty_cache()
             config = (causal, headdim, batch_size, seqlen)
             nheads = dim // headdim
-            q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=torch.bfloat16, requires_grad=False) for _ in range(3)]
+            nkvheads = 1 if is_gqa else nheads
+            if q_seqlen_val is not None:
+                q_seqlen = q_seqlen_val
+                q = torch.rand(batch_size, q_seqlen, nheads, headdim, device=device, dtype=torch.bfloat16, requires_grad=False)
+            else:
+                q = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=torch.bfloat16, requires_grad=False)
+                q_seqlen = seqlen
+            k, v = [torch.randn(batch_size, seqlen, nkvheads, headdim, device=device, dtype=torch.bfloat16, requires_grad=False) for _ in range(2)]
 
-            qkv = torch.stack([q, k, v], dim=2)
-            qkv = qkv.to(torch.bfloat16)
-            f = time_fwd(attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False)
-            time_f[config, "Pytorch"] = f
-            res_baseline = attention_pytorch(qkv, dropout_p, causal=causal)
-
-            if attention_triton is not None:
-                q_transposed = q.transpose(1, 2).contiguous().to(torch.float8_e4m3fn)
-                k_transposed = k.transpose(1, 2).contiguous().to(torch.float8_e4m3fn)
-                v_transposed = v.transpose(1, 2).contiguous().permute(0, 1, 3, 2).to(torch.float8_e4m3fn)
-                scale = 1 / math.sqrt(headdim)
-                f = time_fwd(
-                    attention_triton, q_transposed, k_transposed, v_transposed,
-                    causal, scale, repeats=5, verbose=False, desc='Triton'
-                )
-                f = time_fwd(
-                    attention_triton, q_transposed, k_transposed, v_transposed,
-                    causal, scale, repeats=repeats, verbose=False, desc='Triton'
-                )
-                time_f[config, "Triton"] = f
-                res = attention_triton(
-                    q_transposed, k_transposed, v_transposed.permute(0, 1, 3, 2),
-                    causal, scale
-                ).half().transpose(1, 2)
-                torch.testing.assert_close(res, res_baseline, atol=0.5, rtol=0.5)
+            # qkv = torch.stack([q, k, v], dim=2)
+            # qkv = qkv.to(torch.bfloat16)
+            # f = time_fwd(attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False)
+            # time_f[config, "Pytorch"] = f
+            # res_baseline = attention_pytorch(qkv, dropout_p, causal=causal)
+
+            # if attention_triton is not None:
+            #     q_transposed = q.transpose(1, 2).contiguous().to(torch.float8_e4m3fn)
+            #     k_transposed = k.transpose(1, 2).contiguous().to(torch.float8_e4m3fn)
+            #     v_transposed = v.transpose(1, 2).contiguous().permute(0, 1, 3, 2).to(torch.float8_e4m3fn)
+            #     scale = 1 / math.sqrt(headdim)
+            #     f = time_fwd(
+            #         attention_triton, q_transposed, k_transposed, v_transposed,
+            #         causal, scale, repeats=5, verbose=False, desc='Triton'
+            #     )
+            #     f = time_fwd(
+            #         attention_triton, q_transposed, k_transposed, v_transposed,
+            #         causal, scale, repeats=repeats, verbose=False, desc='Triton'
+            #     )
+            #     time_f[config, "Triton"] = f
+            #     res = attention_triton(
+            #         q_transposed, k_transposed, v_transposed.permute(0, 1, 3, 2),
+            #         causal, scale
+            #     ).half().transpose(1, 2)
+            #     torch.testing.assert_close(res, res_baseline, atol=0.5, rtol=0.5)
 
             # out = torch.empty_like(q)
             q, k, v = q.to(dtype), k.to(dtype), v.to(dtype)
             softmax_scale = q.shape[-1] ** (-0.5)
-            descale_q = torch.tensor([1.0], dtype=torch.float32, device='cuda')
-            descale_k = torch.tensor([1.0], dtype=torch.float32, device='cuda')
-            descale_v = torch.tensor([1.0], dtype=torch.float32, device='cuda')
+            if scaling_recipe == 0:
+                q_descale = torch.tensor([[1.0] * nkvheads] * batch_size, dtype=torch.float32, device='cuda')
+                k_descale = torch.tensor([[1.0] * nkvheads] * batch_size, dtype=torch.float32, device='cuda')
+                v_descale = torch.tensor([[1.0] * nkvheads] * batch_size, dtype=torch.float32, device='cuda')
+            elif scaling_recipe == 1:
+                q_descale = torch.tensor([[1.0] * int(q_seqlen * batch_size)] * nheads, dtype=torch.float32, device='cuda').T
+                k_descale = torch.tensor([[1.0] * int((seqlen + 223) / 224) * batch_size] * nkvheads, dtype=torch.float32, device='cuda').T
+                v_descale = torch.tensor([[1.0] * int((seqlen + 223) / 224) * batch_size] * nkvheads, dtype=torch.float32, device='cuda').T
+            else:
+                raise ValueError(f"Unsupported scaling recipe: {scaling_recipe}")
+
+            # print(f"{q_descale.shape=}, {q_descale.stride()=}, {k_descale.shape=}, {k_descale.stride()=}", flush=True)
 
             # f = time_fwd(flash_attn_func, q, k, v, causal=causal, repeats=repeats, verbose=False)
             f = time_fwd(
                 _flash_attn_forward,
                 q,
                 k,
                 v,
-                softmax_scale,
+                None, # k_new,
+                None, # v_new,
+                None, # qv,
+                None, # out,
+                None, # cu_seqlens_q,
+                None, # cu_seqlens_k,
+                None, # cu_seqlens_k_new,
+                None, # seqused_q,
+                None, # seqused_k,
+                None, # max_seqlen_q,
+                None, # max_seqlen_k,
+                None, # page_table,
+                None, # kv_batch_idx,
+                None, # leftpad_k,
+                None, # rotary_cos,
+                None, # rotary_sin,
+                None, # seqlens_rotary,
+                q_descale=q_descale,
+                k_descale=k_descale,
+                v_descale=v_descale,
+                softmax_scale=softmax_scale,
                 causal=causal,
                 window_size=(-1,-1),
-                descale_q=descale_q,
-                descale_k=descale_k,
-                descale_v=descale_v,
                 repeats=repeats,
-                verbose=False
+                verbose=False,
+                scaling_recipe=scaling_recipe,
             )
 
             # res = flash_attn_func(q, k, v, causal=causal)
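
The two `scaling_recipe` branches above encode different FP8 descale granularities: recipe 0 keeps one descale value per (batch, KV head), while recipe 1 keeps one per query row for Q and one per 224-row block of K/V, transposed so the sequence dimension leads. A shape-only sketch (dummy values; the block size comes from the `(seqlen + 223) / 224` rounding in the diff):

```python
# Shape-only sketch of the two descale layouts (dummy values).
import torch

batch_size, nheads, nkvheads = 1, 16, 1
q_seqlen = seqlen = 2048 * 8

# Recipe 0: one descale per (batch, kv-head).
q_descale_r0 = torch.ones(batch_size, nkvheads, dtype=torch.float32, device='cuda')
print(q_descale_r0.shape)  # (1, 1)

# Recipe 1: per-row descale for Q, per-224-row-block descale for K/V,
# built head-major and transposed so sequence * batch is the leading dim.
n_blocks = (seqlen + 223) // 224
q_descale_r1 = torch.ones(nheads, batch_size * q_seqlen, dtype=torch.float32, device='cuda').T
k_descale_r1 = torch.ones(nkvheads, batch_size * n_blocks, dtype=torch.float32, device='cuda').T
print(q_descale_r1.shape, k_descale_r1.shape)  # (16384, 16), (74, 1)
```
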
@@ -340,12 +403,16 @@ def time_fwd(func, *args, **kwargs):
             print(f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###")
             for method in methods:
                 speed_f[config, method] = efficiency(
-                    flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"),
+                    flops(batch_size, q_seqlen, seqlen, headdim, nheads, causal, mode="fwd"),
+                    time_f[config, method]
+                )
+                mem_bw_f[config, method] = mem_bw(
+                    data_size(batch_size, q_seqlen, seqlen, headdim, nheads, nkvheads, 1 if dtype == torch.float8_e4m3fn else 2, mode="fwd"),
                     time_f[config, method]
                 )
                 #print (time_f[config,method])
                 print(
-                    f"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, {time_f[config, method] * 1e3} ms, "
+                    f"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, {mem_bw_f[config, method]:.2f} TB/s, {time_f[config, method] * 1e3} ms, "
                 )
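
The reason for the new TB/s column: dividing the two counters gives arithmetic intensity in FLOPs per byte, which separates compute-bound prefill from bandwidth-bound decode. A rough sketch reusing the formulas above (illustrative numbers; decode is computed with `causal=False`, since a single query row attends to the whole cache):

```python
# Rough arithmetic-intensity estimate from the counters above (illustrative).
def intensity(batch, q_seqlen, seqlen, headdim, nheads, nkvheads, nbytes, causal):
    f = 4 * batch * q_seqlen * seqlen * nheads * headdim // (2 if causal else 1)
    d = batch * nbytes * headdim * (q_seqlen * nheads * 2 + seqlen * nkvheads * 2)
    return f / d

# fp8 prefill (q_seqlen == seqlen): ~15400 FLOPs/byte -> compute-bound.
print(intensity(1, 16384, 16384, 128, 16, 1, 1, causal=True))
# fp8 decode (q_seqlen == 1): ~32 FLOPs/byte -> memory-bound, so the
# TB/s column is the number that matters there.
print(intensity(1, 1, 16384, 128, 16, 1, 1, causal=False))
```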

hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp (1 addition, 1 deletion)
@@ -883,7 +883,7 @@ struct CollectiveMainloopFwdSm90 {
                 // TODO: uncomment this for cp.async.
                 // copy(scale_copy_v_per_block, tVgV_per_block_scale(_, n_block), tVsV_per_block_scale(_, smem_pipe_write.index()));
                 // TODO: comment out this line to use cp.async.
-                tVsV_per_block_scale(_0{}, smem_pipe_write.index()) = tVgV_per_block_scale(_0{}, n_block);
+                copy(tVgV_per_block_scale(_, n_block), tVsV_per_block_scale(_, smem_pipe_write.index()));
             }
         }
         transpose_V(smem_pipe_write.index());
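
For readers not fluent in CuTe: the removed line assigned only element `_0{}` of the per-block V-scale tile into shared memory, whereas `copy(...)` moves the entire slice. A Python analogy of the behavioral difference (hedged; the real code operates on CuTe tensor views, not torch tensors):

```python
# Python analogy of the fix; the real code copies a CuTe tensor slice.
import torch

gmem_scales = torch.arange(1.0, 5.0)  # per-block V descale values in "gmem"
smem_scales = torch.zeros(4)          # staging buffer in "smem"

smem_scales[0] = gmem_scales[0]       # old line: only the first scale lands
print(smem_scales)                    # tensor([1., 0., 0., 0.])

smem_scales.copy_(gmem_scales)        # new line: every scale is copied
print(smem_scales)                    # tensor([1., 2., 3., 4.])
```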

hopper/test_flash_attn.py (12 additions, 12 deletions)
@@ -49,22 +49,22 @@
 
 
 # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
-# @pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float16] if not DISABLE_FP16 else []) + ([torch.float8_e4m3fn] if not DISABLE_FP8 else []))
-@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn] if not DISABLE_FP8 else [])
+@pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float16] if not DISABLE_FP16 else []) + ([torch.float8_e4m3fn] if not DISABLE_FP8 else []))
+# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn] if not DISABLE_FP8 else [])
 # @pytest.mark.parametrize("dtype", [torch.bfloat16])
 # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
-# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
-@pytest.mark.parametrize("mha_type", ["mha"])
+@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
+# @pytest.mark.parametrize("mha_type", ["mha"])
 # @pytest.mark.parametrize("has_qv", [False, True])
 @pytest.mark.parametrize("has_qv", [False])
-# @pytest.mark.parametrize("deterministic", [False, True])
-@pytest.mark.parametrize("deterministic", [False])
-# @pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else []))
-@pytest.mark.parametrize("softcap", [0.0])
-# @pytest.mark.parametrize("local", [False] + ([True] if not DISABLE_LOCAL else []))
-@pytest.mark.parametrize("local", [False])
-# @pytest.mark.parametrize("causal", [False, True])
-@pytest.mark.parametrize("causal", [False])
+@pytest.mark.parametrize("deterministic", [False, True])
+# @pytest.mark.parametrize("deterministic", [False])
+@pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else []))
+# @pytest.mark.parametrize("softcap", [0.0])
+@pytest.mark.parametrize("local", [False] + ([True] if not DISABLE_LOCAL else []))
+# @pytest.mark.parametrize("local", [False])
+@pytest.mark.parametrize("causal", [False, True])
+# @pytest.mark.parametrize("causal", [False])
 # @pytest.mark.parametrize("causal", [True])
 # @pytest.mark.parametrize("V_colmajor", [False, True])
 @pytest.mark.parametrize("V_colmajor", [False])
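
This hunk restores the full parameter grid that had been narrowed to an FP8, MHA-only configuration during development. Even counting only the axes visible here, the combinations multiply quickly (a rough count that assumes no `DISABLE_*` switch is set and ignores axes defined further down the file):

```python
# Rough count of combinations re-enabled by this hunk (illustrative only).
dtypes = 3         # bfloat16, float16, float8_e4m3fn
mha_types = 3      # "mha", "mqa", "gqa"
deterministic = 2  # False, True
softcap = 2        # 0.0, 15.0
local = 2          # False, True
causal = 2         # False, True
print(dtypes * mha_types * deterministic * softcap * local * causal)  # 144
```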
