1515
1616from benchmarks .utils import benchmark_cuda_function_in_microseconds
1717from torchao .prototype .moe_training .kernels .mxfp8 import (
18+ mx_block_rearrange_2d_M_groups_cuda ,
1819 torch_to_blocked_2d_M_groups ,
1920 triton_mx_block_rearrange_2d_M_groups ,
2021)
3031class ExperimentConfig :
3132 input_shape : tuple [int ]
3233 num_groups : int
34+ chunk_width : int
35+ chunks_per_tb : int
3336
3437
3538@dataclass (frozen = True )
3639class ExperimentResult :
3740 torch_time_us : float
3841 triton_time_us : float
42+ cuda_time_us : float
3943 torch_mem_bw_gbps : float
4044 triton_mem_bw_gbps : float
45+ cuda_mem_bw_gbps : float
4146
4247
4348@dataclass (frozen = True )
@@ -47,29 +52,39 @@ class Experiment:
4752
4853
4954def get_configs () -> List [ExperimentConfig ]:
50- # Llama4 shapes. Input activations are scaled along K dim.
5155 block_size = 32
5256 input_shapes = [
53- (16640 , 5120 // block_size ),
57+ (16384 , 5120 // block_size ),
5458 (131072 , 5120 // block_size ),
59+ (131072 , 2048 // block_size ),
60+ (131072 , 7168 // block_size ),
5561 ]
56- num_groups = [16 ]
62+ num_groups = [8 ]
63+ chunk_width_list = [64 ]
64+ chunks_per_tb_list = [1 , 4 , 8 ]
65+
5766 configs = []
58- for shape , groups in itertools .product (
67+ for shape , groups , chunk_width , chunks_per_tb in itertools .product (
5968 input_shapes ,
6069 num_groups ,
70+ chunk_width_list ,
71+ chunks_per_tb_list ,
6172 ):
6273 configs .append (
6374 ExperimentConfig (
6475 input_shape = shape ,
6576 num_groups = groups ,
77+ chunk_width = chunk_width ,
78+ chunks_per_tb = chunks_per_tb ,
6679 )
6780 )
6881 return configs
6982
7083
7184def run_experiment (config : ExperimentConfig ) -> ExperimentResult :
7285 input_shape , num_groups = config .input_shape , config .num_groups
86+ chunk_width , chunks_per_tb = config .chunk_width , config .chunks_per_tb
87+
7388 input_tensor = torch .randint (
7489 low = 0 ,
7590 high = 256 ,
@@ -107,6 +122,21 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
107122 input_group_offsets ,
108123 )
109124
125+ # bench CUDA pipelined kernel with configured chunk_width and chunks_per_tb
126+ _ = mx_block_rearrange_2d_M_groups_cuda (
127+ input_tensor .view (torch .uint8 ),
128+ input_group_offsets .to (torch .int32 ),
129+ chunk_width ,
130+ chunks_per_tb ,
131+ )
132+ cuda_time_us = benchmark_cuda_function_in_microseconds (
133+ mx_block_rearrange_2d_M_groups_cuda ,
134+ input_tensor .view (torch .uint8 ),
135+ input_group_offsets .to (torch .int32 ),
136+ chunk_width ,
137+ chunks_per_tb ,
138+ )
139+
110140 # mem bw calculations
111141 bytes_per_input_el = torch .finfo (torch .float8_e8m0fnu ).bits / 8
112142 bytes_per_output_el = torch .finfo (torch .float8_e4m3fn ).bits / 8
@@ -116,23 +146,28 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
116146
117147 torch_mem_bw_gbps = ((read_bytes + write_bytes ) / 1e9 ) / (torch_time_us / 1e6 )
118148 triton_mem_bw_gbps = ((read_bytes + write_bytes ) / 1e9 ) / (triton_time_us / 1e6 )
149+ cuda_mem_bw_gbps = ((read_bytes + write_bytes ) / 1e9 ) / (cuda_time_us / 1e6 )
119150
120151 return ExperimentResult (
121152 torch_time_us = torch_time_us ,
122153 triton_time_us = triton_time_us ,
154+ cuda_time_us = cuda_time_us ,
123155 torch_mem_bw_gbps = torch_mem_bw_gbps ,
124156 triton_mem_bw_gbps = triton_mem_bw_gbps ,
157+ cuda_mem_bw_gbps = cuda_mem_bw_gbps ,
125158 )
126159
127160
128161def print_results (experiments : List [Experiment ]):
129162 headers = [
130163 "input_shape" ,
164+ "chunk_width" ,
165+ "chunks_per_tb" ,
131166 "torch_time_us" ,
132167 "triton_time_us" ,
133- "torch_mem_bw_gbps" ,
134- "triton_mem_bw_gbps" ,
168+ "cuda_time_us" ,
135169 "triton_speedup" ,
170+ "cuda_speedup" ,
136171 ]
137172 rows = []
138173 for experiment in experiments :
@@ -142,11 +177,13 @@ def print_results(experiments: List[Experiment]):
142177 rows .append (
143178 [
144179 input_shape ,
145- experiment .result .torch_time_us ,
146- experiment .result .triton_time_us ,
147- round (experiment .result .torch_mem_bw_gbps , 3 ),
148- round (experiment .result .triton_mem_bw_gbps , 3 ),
180+ experiment .config .chunk_width ,
181+ experiment .config .chunks_per_tb ,
182+ f"{ experiment .result .torch_time_us :.2f} " ,
183+ f"{ experiment .result .triton_time_us :.2f} " ,
184+ f"{ experiment .result .cuda_time_us :.2f} " ,
149185 f"{ experiment .result .torch_time_us / experiment .result .triton_time_us :.2f} x" ,
186+ f"{ experiment .result .torch_time_us / experiment .result .cuda_time_us :.2f} x" ,
150187 ]
151188 )
152189 print (tabulate (rows , headers = headers ))
0 commit comments