# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

+import fbgemm_gpu.experimental.gen_ai  # noqa: F401
import torch
from torch._inductor.utils import do_bench_using_profiling


class SweepHeuristics:
    def __init__(self):
-        self.m = 1
+        self.ms = [1, 2, 3, 4]
        self.block_dims = [
            (32, 1),
            (32, 4),
@@ -35,52 +36,89 @@ def __init__(self):
        ]
        self.nks = [(1280, 8192), (8192, 1024), (7168, 8192), (8192, 3584)]

-    def sweep_heuristics(self, fn) -> None:
-        for n, k in self.nks:
-            x = torch.randn(size=(self.m, k), dtype=torch.half, device="cuda")
-            w = torch.randn(size=(n, k), dtype=torch.half, device="cuda")
-            best_elapsed_time, best_block_dim_x, best_block_dim_y = None, None, None
+    def sweep_heuristics(self, fn, quantize_w=False, quantize_x=False) -> None:
+        for m in self.ms:
+            for n, k in self.nks:
+                x = torch.randn(size=(m, k), dtype=torch.bfloat16, device="cuda")
+                w = torch.randn(size=(n, k), dtype=torch.bfloat16, device="cuda")

-            for block_dim_x, block_dim_y in self.block_dims:
-                if (
-                    (k % block_dim_x != 0)
-                    or (n % block_dim_x != 0)
-                    or ((k / block_dim_x) % 8 != 0)
-                ):
-                    continue
-                # This requires:
-                # 1. updating (for testing purposes) the PyTorch custom gemv op to accept the additional params block_dim_x and block_dim_y
-                # 2. modifying the corresponding `{precision}_fast_gemv.cu` kernel signature to reflect the block_dim_x and block_dim_y heuristics
-                # e.g. https://www.internalfb.com/code/fbsource/[208a27f25373]/fbcode/deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py?lines=375
-                res = do_bench_using_profiling(
-                    lambda func=fn, x=x, w=w, block_dim_x=block_dim_x, block_dim_y=block_dim_y: func(
-                        x, w, block_dim_x, block_dim_y
-                    )
-                )
+                best_elapsed_time, best_block_dim_x, best_block_dim_y = None, None, None
+
+                for block_dim_x, block_dim_y in self.block_dims:
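+                    # Skip block shapes that don't evenly tile n and k; the
+                    # kernel is assumed to read k / block_dim_x elements per
+                    # thread in vectorized groups of 8, hence the last check.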
+                    if (
+                        (k % block_dim_x != 0)
+                        or (n % block_dim_x != 0)
+                        or ((k / block_dim_x) % 8 != 0)
+                    ):
+                        continue
+
+                    res = 0.0
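+                    # do_bench_using_profiling benchmarks the callable under
+                    # the torch profiler and reports time in ms; the lambda
+                    # default args bind the loop values at definition time.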
+                    if quantize_w and quantize_x:
+                        xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
+                        wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
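+                        # quantize_fp8_per_tensor returns the quantized tensor
+                        # and a dequantization scale; the x and w scales fold
+                        # into the single rescale passed to the kernel below.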
+                        res = do_bench_using_profiling(
+                            lambda func=fn, x=xq, w=wq, scale=x_scale * w_scale, block_dim_x=block_dim_x, block_dim_y=block_dim_y: func(
+                                x, w, scale, block_dim_x, block_dim_y
+                            )
+                        )
+                    elif quantize_w:
+                        wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
+                        res = do_bench_using_profiling(
+                            lambda func=fn, x=x, w=wq, scale=w_scale, block_dim_x=block_dim_x, block_dim_y=block_dim_y: func(
+                                x, w, scale, block_dim_x, block_dim_y
+                            )
+                        )
+                    else:
+                        # This requires:
+                        # 1. updating (for testing purposes) the PyTorch custom gemv op to accept the additional params block_dim_x and block_dim_y
+                        # 2. modifying the corresponding `{precision}_fast_gemv.cu` kernel signature to reflect the block_dim_x and block_dim_y heuristics
+                        # e.g. https://www.internalfb.com/code/fbsource/[208a27f25373]/fbcode/deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py?lines=375
+                        res = do_bench_using_profiling(
+                            lambda func=fn, x=x, w=w, block_dim_x=block_dim_x, block_dim_y=block_dim_y: func(
+                                x, w, block_dim_x, block_dim_y
+                            )
+                        )

-                if best_elapsed_time is None or res < best_elapsed_time:
-                    best_elapsed_time, best_block_dim_x, best_block_dim_y = (
-                        res,
-                        block_dim_x,
-                        block_dim_y,
+                    if best_elapsed_time is None or res < best_elapsed_time:
+                        best_elapsed_time, best_block_dim_x, best_block_dim_y = (
+                            res,
+                            block_dim_x,
+                            block_dim_y,
+                        )
+                if best_elapsed_time is None:
+                    print("Error: No valid elapsed time found. Exiting the function.")
+                    return
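+                # Achieved bandwidth in GiB/s: bytes read for x and w plus
+                # bytes written for the m x n output, over the best time.
+                # bf16 elements take 2 bytes; fp8 elements take 1 byte.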
+                if fn == torch.ops.fbgemm.bf16_fast_gemv:
+                    bw = (
+                        (m * k * 2 + n * k * 2 + m * n * 2)
+                        / (best_elapsed_time / 1000)
+                        / (1024**3)
+                    )
+                elif fn == torch.ops.fbgemm.bf16fp8bf16_fast_gemv:
+                    bw = (
+                        (m * k * 2 + n * k + m * n * 2)
+                        / (best_elapsed_time / 1000)
+                        / (1024**3)
+                    )
+                else:  # assuming fn is torch.ops.fbgemm.fp8fp8bf16_fast_gemv
+                    bw = (
+                        (m * k + n * k + m * n * 2)
+                        / (best_elapsed_time / 1000)
+                        / (1024**3)
                    )
-            if best_elapsed_time is None:
-                print("Error: No valid elapsed time found. Exiting the function.")
-                return
-            bw = (
-                (self.m * k * 2 + n * k * 2 + self.m * n * 2)
-                / (best_elapsed_time / 1000)
-                / (1024**3)
-            )
-            print(f"m: {self.m}, n: {n}, k: {k}")
-            print(f"tuning heuristics for kernel: {fn.__name__}")
-            print(f"best elapsed time: {best_elapsed_time} ms")
-            print(f"best block_dim_x: {best_block_dim_x}")
-            print(f"best block_dim_y: {best_block_dim_y}")
-            print(f"best bw: {bw} GB/s")
+                print(f"m: {m}, n: {n}, k: {k}")
+                print(f"tuning heuristics for kernel: {fn.__name__}")
+                print(f"best elapsed time: {best_elapsed_time} ms")
+                print(f"best block_dim_x: {best_block_dim_x}")
+                print(f"best block_dim_y: {best_block_dim_y}")
+                print(f"best bw: {bw} GB/s")


sweep_instance = SweepHeuristics()
sweep_instance.sweep_heuristics(fn=torch.ops.fbgemm.bf16_fast_gemv)
-sweep_instance.sweep_heuristics(fn=torch.ops.fbgemm.bf16fp8bf16_fast_gemv)
-sweep_instance.sweep_heuristics(fn=torch.ops.fbgemm.fp8fp8bf16_fast_gemv)
+sweep_instance.sweep_heuristics(
+    fn=torch.ops.fbgemm.bf16fp8bf16_fast_gemv, quantize_w=True
+)
+sweep_instance.sweep_heuristics(
+    fn=torch.ops.fbgemm.fp8fp8bf16_fast_gemv, quantize_w=True, quantize_x=True
+)