     import cudnn
 except ImportError:
     cudnn = None
+cudnn = None
 
 
 def convert_to_cudnn_type(torch_type):
@@ -198,14 +199,22 @@ def attention_pytorch(qkv, dropout_p=0.0, causal=True):
     output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
     return output.to(dtype=qkv.dtype)
 
-def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
+def flops(batch, q_seqlen, seqlen, headdim, nheads, causal, mode="fwd"):
     assert mode in ["fwd", "bwd", "fwd_bwd"]
-    f = 4 * batch * seqlen ** 2 * nheads * headdim // (2 if causal else 1)
+    f = 4 * batch * q_seqlen * seqlen * nheads * headdim // (2 if causal else 1)
     return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)
 
 def efficiency(flop, time):
     return (flop / time / 10 ** 12) if not math.isnan(time) else 0.0
 
+def data_size(batch, q_seqlen, seqlen, headdim, nheads, nkvheads, nbytes, mode="fwd"):
+    assert mode in ["fwd"]
+    d_size = batch * nbytes * headdim * (q_seqlen * nheads * 2 + seqlen * nkvheads * 2)
+    return d_size
+
+def mem_bw(nbytes, time):
+    return (nbytes / time / 1024 / 1024 / 1024 / 1024) if not math.isnan(time) else 0.0
+
 def time_fwd(func, *args, **kwargs):
     time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark
     time_f = benchmark_forward(func, *args, **kwargs)
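Since the new flops/data_size helpers drive everything printed later, here is a quick worked check of the formulas for the one config this patch leaves enabled (batch 1, 16K-token causal prefill, headdim 128, 16 query heads, 1 KV head, FP8). This is a standalone sketch for the reader, not part of the patch:

# Re-derive the helpers above for batch=1, q_seqlen=seqlen=16384, headdim=128,
# nheads=16, nkvheads=1, causal=True, FP8 inputs (nbytes=1).
batch, q_seqlen, seqlen, headdim, nheads, nkvheads, nbytes = 1, 16384, 16384, 128, 16, 1, 1

# QK^T and PV each cost 2 * q_seqlen * seqlen * headdim FLOPs per query head;
# causal masking halves the total.
flop = 4 * batch * q_seqlen * seqlen * nheads * headdim // 2
print(flop)                    # 1_099_511_627_776, roughly 1.1 TFLOP per forward

# The factor of 2 on each side appears to count Q plus the output O on the
# query side, and K plus V on the key/value side, all at nbytes per element.
d_size = batch * nbytes * headdim * (q_seqlen * nheads * 2 + seqlen * nkvheads * 2)
print(d_size)                  # 71_303_168 bytes (68 MiB)

# With a measured time t, the benchmark would report:
t = 1e-3                       # seconds (placeholder, not a measurement)
print(flop / t / 10 ** 12)     # TFLOPs/s, as in efficiency()
print(d_size / t / 1024 ** 4)  # "TB/s", as in mem_bw()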
@@ -216,30 +225,50 @@ def time_fwd(func, *args, **kwargs):
 
 repeats = 30
 device = 'cuda'
-# dtype = torch.float16
+dtype = torch.float16
 dtype = torch.float8_e4m3fn
+is_gqa = True
 
+# For prefill
+q_seqlen_val = None
 # bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4224), (2, 8448), (1, 8448 * 2)]
-bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 8192 * 2)]
+# bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 8192 * 2)]
+bs_seqlen_vals = [(1, 2048 * 8)]
+# bs_seqlen_vals = [(32, 8192), (32, 2048)]
+
+# For decode
+# q_seqlen_val = 1
+# q_seqlen_val = 4
+# bs_seqlen_vals = [(1, 128)]
 # bs_seqlen_vals = [(4, 4096), (2, 8192), (1, 8192 * 2)]
 # bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048)]
-causal_vals = [False, True]
-headdim_vals = [64, 128, 256]
-dim = 2048
+# bs_seqlen_vals = [(32, 8192*4), (32, 8192*2), (32, 8192), (32, 4096), (32, 2048), (32, 1024), (64, 8192*2), (128, 8192), (128, 4096), (128, 2048), (128, 1024)]
+
+
+# causal_vals = [False, True]
+causal_vals = [True]
+# headdim_vals = [64, 128, 256]
+headdim_vals = [128]
+dim = 128 * 16
 # dim = 256
 dropout_p = 0.0
 
-methods = (["Pytorch", "Flash3"]
-           + (["cuDNN"] if cudnn is not None else [])
-           # + (["Triton"] if attention_triton is not None else [])
-           # + (["xformers.c"] if xops is not None else [])
-           # + (["xformers.f"] if xops is not None else [])
-           )
+# scaling_recipe = 1
+scaling_recipe = 0
+
+methods = (["Flash3"])
+# methods = (["Pytorch", "Flash3"]
+#            + (["cuDNN"] if cudnn is not None else [])
+#            # + (["Triton"] if attention_triton is not None else [])
+#            # + (["xformers.c"] if xops is not None else [])
+#            # + (["xformers.f"] if xops is not None else [])
+#            )
 
 time_f = {}
 time_b = {}
 time_f_b = {}
 speed_f = {}
+mem_bw_f = {}
 speed_b = {}
 speed_f_b = {}
 for causal in causal_vals:
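For orientation: with the values left uncommented above (headdim 128, dim = 128 * 16 so 16 query heads, is_gqa giving a single KV head, one batch of a 16K-token prefill, FP8 compute dtype), the loop below builds tensors with the shapes sketched here. A standalone illustration, not part of the patch:

import torch

batch_size, seqlen, q_seqlen = 1, 2048 * 8, 2048 * 8   # prefill: q_seqlen_val is None
headdim, nheads, nkvheads = 128, 16, 1                  # nheads = dim // headdim; 1 KV head because is_gqa

q = torch.randn(batch_size, q_seqlen, nheads, headdim, dtype=torch.bfloat16)
k = torch.randn(batch_size, seqlen, nkvheads, headdim, dtype=torch.bfloat16)
v = torch.randn(batch_size, seqlen, nkvheads, headdim, dtype=torch.bfloat16)
print(q.shape, k.shape, v.shape)
# torch.Size([1, 16384, 16, 128]) torch.Size([1, 16384, 1, 128]) torch.Size([1, 16384, 1, 128])
# The decode settings (q_seqlen_val = 1) shrink only q to [1, 1, 16, 128]; K and V keep the
# full context length, so the data_size term is dominated by the KV reads in that case.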
@@ -248,55 +277,89 @@ def time_fwd(func, *args, **kwargs):
             torch.cuda.empty_cache()
             config = (causal, headdim, batch_size, seqlen)
             nheads = dim // headdim
-            q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=torch.bfloat16, requires_grad=False) for _ in range(3)]
+            nkvheads = 1 if is_gqa else nheads
+            if q_seqlen_val is not None:
+                q_seqlen = q_seqlen_val
+                q = torch.rand(batch_size, q_seqlen, nheads, headdim, device=device, dtype=torch.bfloat16, requires_grad=False)
+            else:
+                q = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=torch.bfloat16, requires_grad=False)
+                q_seqlen = seqlen
+            k, v = [torch.randn(batch_size, seqlen, nkvheads, headdim, device=device, dtype=torch.bfloat16, requires_grad=False) for _ in range(2)]
 
-            qkv = torch.stack([q, k, v], dim=2)
-            qkv = qkv.to(torch.bfloat16)
-            f = time_fwd(attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False)
-            time_f[config, "Pytorch"] = f
-            res_baseline = attention_pytorch(qkv, dropout_p, causal=causal)
-
-            if attention_triton is not None:
-                q_transposed = q.transpose(1, 2).contiguous().to(torch.float8_e4m3fn)
-                k_transposed = k.transpose(1, 2).contiguous().to(torch.float8_e4m3fn)
-                v_transposed = v.transpose(1, 2).contiguous().permute(0, 1, 3, 2).to(torch.float8_e4m3fn)
-                scale = 1 / math.sqrt(headdim)
-                f = time_fwd(
-                    attention_triton, q_transposed, k_transposed, v_transposed,
-                    causal, scale, repeats=5, verbose=False, desc='Triton'
-                )
-                f = time_fwd(
-                    attention_triton, q_transposed, k_transposed, v_transposed,
-                    causal, scale, repeats=repeats, verbose=False, desc='Triton'
-                )
-                time_f[config, "Triton"] = f
-                res = attention_triton(
-                    q_transposed, k_transposed, v_transposed.permute(0, 1, 3, 2),
-                    causal, scale
-                ).half().transpose(1, 2)
-                torch.testing.assert_close(res, res_baseline, atol=0.5, rtol=0.5)
+            # qkv = torch.stack([q, k, v], dim=2)
+            # qkv = qkv.to(torch.bfloat16)
+            # f = time_fwd(attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False)
+            # time_f[config, "Pytorch"] = f
+            # res_baseline = attention_pytorch(qkv, dropout_p, causal=causal)
+
+            # if attention_triton is not None:
+            #     q_transposed = q.transpose(1, 2).contiguous().to(torch.float8_e4m3fn)
+            #     k_transposed = k.transpose(1, 2).contiguous().to(torch.float8_e4m3fn)
+            #     v_transposed = v.transpose(1, 2).contiguous().permute(0, 1, 3, 2).to(torch.float8_e4m3fn)
+            #     scale = 1 / math.sqrt(headdim)
+            #     f = time_fwd(
+            #         attention_triton, q_transposed, k_transposed, v_transposed,
+            #         causal, scale, repeats=5, verbose=False, desc='Triton'
+            #     )
+            #     f = time_fwd(
+            #         attention_triton, q_transposed, k_transposed, v_transposed,
+            #         causal, scale, repeats=repeats, verbose=False, desc='Triton'
+            #     )
+            #     time_f[config, "Triton"] = f
+            #     res = attention_triton(
+            #         q_transposed, k_transposed, v_transposed.permute(0, 1, 3, 2),
+            #         causal, scale
+            #     ).half().transpose(1, 2)
+            #     torch.testing.assert_close(res, res_baseline, atol=0.5, rtol=0.5)
 
             # out = torch.empty_like(q)
             q, k, v = q.to(dtype), k.to(dtype), v.to(dtype)
             softmax_scale = q.shape[-1] ** (-0.5)
-            descale_q = torch.tensor([1.0], dtype=torch.float32, device='cuda')
-            descale_k = torch.tensor([1.0], dtype=torch.float32, device='cuda')
-            descale_v = torch.tensor([1.0], dtype=torch.float32, device='cuda')
+            if scaling_recipe == 0:
+                q_descale = torch.tensor([[1.0] * nkvheads] * batch_size, dtype=torch.float32, device='cuda')
+                k_descale = torch.tensor([[1.0] * nkvheads] * batch_size, dtype=torch.float32, device='cuda')
+                v_descale = torch.tensor([[1.0] * nkvheads] * batch_size, dtype=torch.float32, device='cuda')
+            elif scaling_recipe == 1:
+                q_descale = torch.tensor([[1.0] * int(q_seqlen * batch_size)] * nheads, dtype=torch.float32, device='cuda').T
+                k_descale = torch.tensor([[1.0] * int((seqlen + 223) / 224) * batch_size] * nkvheads, dtype=torch.float32, device='cuda').T
+                v_descale = torch.tensor([[1.0] * int((seqlen + 223) / 224) * batch_size] * nkvheads, dtype=torch.float32, device='cuda').T
+            else:
+                raise ValueError(f"Unsupported scaling recipe: {scaling_recipe}")
+
+            # print(f"{q_descale.shape=}, {q_descale.stride()=}, {k_descale.shape=}, {k_descale.stride()=}", flush=True)
 
             # f = time_fwd(flash_attn_func, q, k, v, causal=causal, repeats=repeats, verbose=False)
             f = time_fwd(
                 _flash_attn_forward,
                 q,
                 k,
                 v,
-                softmax_scale,
+                None, # k_new,
+                None, # v_new,
+                None, # qv,
+                None, # out,
+                None, # cu_seqlens_q,
+                None, # cu_seqlens_k,
+                None, # cu_seqlens_k_new,
+                None, # seqused_q,
+                None, # seqused_k,
+                None, # max_seqlen_q,
+                None, # max_seqlen_k,
+                None, # page_table,
+                None, # kv_batch_idx,
+                None, # leftpad_k,
+                None, # rotary_cos,
+                None, # rotary_sin,
+                None, # seqlens_rotary,
+                q_descale=q_descale,
+                k_descale=k_descale,
+                v_descale=v_descale,
+                softmax_scale=softmax_scale,
                 causal=causal,
                 window_size=(-1, -1),
-                descale_q=descale_q,
-                descale_k=descale_k,
-                descale_v=descale_v,
                 repeats=repeats,
-                verbose=False
+                verbose=False,
+                scaling_recipe=scaling_recipe,
             )
 
             # res = flash_attn_func(q, k, v, causal=causal)
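The two descale layouts introduced above are easy to misread, so here is a shape-only sketch under the same default config. Reading the 224-row granularity in recipe 1 as the kernel's K/V tile height is an assumption on my part, not something the patch states:

import math
import torch

batch_size, q_seqlen, seqlen = 1, 16384, 16384
nheads, nkvheads = 16, 1

# Recipe 0: one descale factor per (batch, KV head) for each of Q, K and V.
q_descale0 = torch.ones(batch_size, nkvheads, dtype=torch.float32)
print(q_descale0.shape)  # torch.Size([1, 1])

# Recipe 1: finer-grained factors. Q gets one per (token, query head); K and V get
# one per 224-row block of the key/value sequence per KV head (224 assumed to be
# the kernel's K/V tile height).
q_descale1 = torch.ones(batch_size * q_seqlen, nheads, dtype=torch.float32)
kv_blocks = math.ceil(seqlen / 224) * batch_size
k_descale1 = torch.ones(kv_blocks, nkvheads, dtype=torch.float32)
print(q_descale1.shape, k_descale1.shape)  # torch.Size([16384, 16]) torch.Size([74, 1])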
@@ -340,12 +403,16 @@ def time_fwd(func, *args, **kwargs):
             print(f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###")
             for method in methods:
                 speed_f[config, method] = efficiency(
-                    flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"),
+                    flops(batch_size, q_seqlen, seqlen, headdim, nheads, causal, mode="fwd"),
+                    time_f[config, method]
+                )
+                mem_bw_f[config, method] = mem_bw(
+                    data_size(batch_size, q_seqlen, seqlen, headdim, nheads, nkvheads, 1 if dtype == torch.float8_e4m3fn else 2, mode="fwd"),
                     time_f[config, method]
                 )
                 #print (time_f[config,method])
                 print(
-                    f"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, {time_f[config, method] * 1e3} ms, "
+                    f"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, {mem_bw_f[config, method]:.2f} TB/s, {time_f[config, method] * 1e3} ms, "
                 )
 
 
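One unit nuance when reading the new output line: mem_bw divides by 1024 ** 4, so the value labelled "TB/s" is strictly tebibytes per second. A throwaway check of the conversion factor:

# 1 TiB = 1024**4 bytes vs 1 TB = 10**12 bytes: multiply the printed "TB/s"
# value by roughly 1.0995 to get decimal terabytes per second.
print(1024 ** 4 / 10 ** 12)  # 1.099511627776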