Commit 7a12730

[mxfp8 moe training] add CUDA kernel for per group blocked layout with groups along M
stack-info: PR: #3546, branch: danielvegamyhre/stack/92
1 parent f023fab commit 7a12730
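
For orientation before the file-by-file diff: a minimal usage sketch of the new kernel, mirroring the test added in this commit. The shapes, the import paths for the `to_mx` and `generate_jagged_offs` helpers, and the tuning values are illustrative assumptions; per the test's skip condition, the kernel targets GPUs with CUDA capability 10.0 or greater.

# Minimal usage sketch (assumptions noted above); mirrors
# test_cuda_mx_block_rearrange_2d_M_groups added in this commit.
import torch

from torchao.prototype.moe_training.kernels.mxfp8 import (
    mx_block_rearrange_2d_M_groups_cuda,
    torch_to_blocked_2d_M_groups,
)
from torchao.prototype.moe_training.utils import generate_jagged_offs  # assumed path
from torchao.prototype.mx_formats.mx_tensor import to_mx  # assumed path

m, k, n_groups, block_size = 16640, 2048, 8, 32
x = torch.randn(m, k, device="cuda")

# Per-32-element-block E8M0 scales, one scale column per block along K.
e8m0_scales, _ = to_mx(x, elem_dtype=torch.float8_e4m3fn, block_size=block_size)

# Jagged end-offsets of each expert group along M, aligned to block_size.
offs = generate_jagged_offs(n_groups, m, multiple_of=block_size, device="cuda")

# CUDA kernel: rearrange each group's scales into the blocked layout.
# chunk_width / chunks_per_tb are tuning knobs swept by the benchmark below.
blocked = mx_block_rearrange_2d_M_groups_cuda(
    e8m0_scales, offs, chunk_width=64, chunks_per_tb=4
)

# Reference implementation the test compares against bit-for-bit.
ref, _ = torch_to_blocked_2d_M_groups(e8m0_scales, offs, block_size=block_size)
assert torch.equal(ref, blocked)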

File tree: 7 files changed (+814, -10 lines)

benchmarks/prototype/moe_training/mxfp8/bench_triton_mx_block_rearrange_2d_M_groups.py renamed to benchmarks/prototype/moe_training/mxfp8/bench_mx_block_rearrange_2d_M_groups.py

Lines changed: 47 additions & 10 deletions
@@ -15,6 +15,7 @@
 
 from benchmarks.utils import benchmark_cuda_function_in_microseconds
 from torchao.prototype.moe_training.kernels.mxfp8 import (
+    mx_block_rearrange_2d_M_groups_cuda,
     torch_to_blocked_2d_M_groups,
     triton_mx_block_rearrange_2d_M_groups,
 )
@@ -30,14 +31,18 @@
 class ExperimentConfig:
     input_shape: tuple[int]
     num_groups: int
+    chunk_width: int
+    chunks_per_tb: int
 
 
 @dataclass(frozen=True)
 class ExperimentResult:
     torch_time_us: float
     triton_time_us: float
+    cuda_time_us: float
     torch_mem_bw_gbps: float
     triton_mem_bw_gbps: float
+    cuda_mem_bw_gbps: float
 
 
 @dataclass(frozen=True)
@@ -47,29 +52,39 @@ class Experiment:
 
 
 def get_configs() -> List[ExperimentConfig]:
-    # Llama4 shapes. Input activations are scaled along K dim.
     block_size = 32
     input_shapes = [
-        (16640, 5120 // block_size),
+        (16384, 5120 // block_size),
         (131072, 5120 // block_size),
+        (131072, 2048 // block_size),
+        (131072, 7168 // block_size),
     ]
-    num_groups = [16]
+    num_groups = [8]
+    chunk_width_list = [64]
+    chunks_per_tb_list = [1, 4, 8]
+
     configs = []
-    for shape, groups in itertools.product(
+    for shape, groups, chunk_width, chunks_per_tb in itertools.product(
         input_shapes,
         num_groups,
+        chunk_width_list,
+        chunks_per_tb_list,
     ):
         configs.append(
             ExperimentConfig(
                 input_shape=shape,
                 num_groups=groups,
+                chunk_width=chunk_width,
+                chunks_per_tb=chunks_per_tb,
             )
         )
     return configs
 
 
 def run_experiment(config: ExperimentConfig) -> ExperimentResult:
     input_shape, num_groups = config.input_shape, config.num_groups
+    chunk_width, chunks_per_tb = config.chunk_width, config.chunks_per_tb
+
     input_tensor = torch.randint(
         low=0,
         high=256,
@@ -107,6 +122,21 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
         input_group_offsets,
     )
 
+    # bench CUDA pipelined kernel with configured chunk_width and chunks_per_tb
+    _ = mx_block_rearrange_2d_M_groups_cuda(
+        input_tensor.view(torch.uint8),
+        input_group_offsets.to(torch.int32),
+        chunk_width,
+        chunks_per_tb,
+    )
+    cuda_time_us = benchmark_cuda_function_in_microseconds(
+        mx_block_rearrange_2d_M_groups_cuda,
+        input_tensor.view(torch.uint8),
+        input_group_offsets.to(torch.int32),
+        chunk_width,
+        chunks_per_tb,
+    )
+
     # mem bw calculations
     bytes_per_input_el = torch.finfo(torch.float8_e8m0fnu).bits / 8
     bytes_per_output_el = torch.finfo(torch.float8_e4m3fn).bits / 8
@@ -116,23 +146,28 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
 
     torch_mem_bw_gbps = ((read_bytes + write_bytes) / 1e9) / (torch_time_us / 1e6)
     triton_mem_bw_gbps = ((read_bytes + write_bytes) / 1e9) / (triton_time_us / 1e6)
+    cuda_mem_bw_gbps = ((read_bytes + write_bytes) / 1e9) / (cuda_time_us / 1e6)
 
     return ExperimentResult(
         torch_time_us=torch_time_us,
         triton_time_us=triton_time_us,
+        cuda_time_us=cuda_time_us,
         torch_mem_bw_gbps=torch_mem_bw_gbps,
         triton_mem_bw_gbps=triton_mem_bw_gbps,
+        cuda_mem_bw_gbps=cuda_mem_bw_gbps,
     )
 
 
 def print_results(experiments: List[Experiment]):
     headers = [
         "input_shape",
+        "chunk_width",
+        "chunks_per_tb",
         "torch_time_us",
         "triton_time_us",
-        "torch_mem_bw_gbps",
-        "triton_mem_bw_gbps",
+        "cuda_time_us",
         "triton_speedup",
+        "cuda_speedup",
     ]
     rows = []
     for experiment in experiments:
@@ -142,11 +177,13 @@ def print_results(experiments: List[Experiment]):
         rows.append(
             [
                 input_shape,
-                experiment.result.torch_time_us,
-                experiment.result.triton_time_us,
-                round(experiment.result.torch_mem_bw_gbps, 3),
-                round(experiment.result.triton_mem_bw_gbps, 3),
+                experiment.config.chunk_width,
+                experiment.config.chunks_per_tb,
+                f"{experiment.result.torch_time_us:.2f}",
+                f"{experiment.result.triton_time_us:.2f}",
+                f"{experiment.result.cuda_time_us:.2f}",
                 f"{experiment.result.torch_time_us / experiment.result.triton_time_us:.2f}x",
+                f"{experiment.result.torch_time_us / experiment.result.cuda_time_us:.2f}x",
             ]
         )
     print(tabulate(rows, headers=headers))
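
A note on the timing columns above: benchmark_cuda_function_in_microseconds comes from benchmarks.utils, and each mem_bw_gbps figure is simply (read_bytes + write_bytes) / 1e9 divided by the measured time in seconds. As a rough illustration only (not the actual benchmarks.utils implementation; the helper name, warmup count, and iteration count below are assumptions), CUDA-event timing of this kind looks like:

import torch

def bench_cuda_us(fn, *args, warmup: int = 3, iters: int = 10, **kwargs) -> float:
    # Warm up so compilation / caching effects don't pollute the measurement.
    for _ in range(warmup):
        fn(*args, **kwargs)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args, **kwargs)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) * 1e3 / iters  # ms -> us, averaged over iters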

setup.py

Lines changed: 1 addition & 0 deletions
@@ -709,6 +709,7 @@ def get_extensions():
     mxfp8_sources = [
         os.path.join(mxfp8_extension_dir, "mxfp8_extension.cpp"),
         os.path.join(mxfp8_extension_dir, "mxfp8_cuda.cu"),
+        os.path.join(mxfp8_extension_dir, "mx_block_rearrange_2d_M_groups.cu"),
     ]
 
     # Only add the extension if the source files exist AND we are building for sm100

test/prototype/moe_training/test_kernels.py

Lines changed: 45 additions & 0 deletions
@@ -21,6 +21,7 @@
     triton_fp8_per_group_rowwise_scales,
 )
 from torchao.prototype.moe_training.kernels.mxfp8 import (
+    mx_block_rearrange_2d_M_groups_cuda,
     mxfp8_quantize_cuda_3d,
     torch_to_blocked_2d_K_groups,
     torch_to_blocked_2d_M_groups,
@@ -244,6 +245,50 @@ def test_triton_mx_block_rearrange_2d_M_groups(
     )
 
 
+@pytest.mark.skipif(
+    not is_sm_at_least_100(),
+    reason="MXFP8 requires CUDA capability 10.0 or greater",
+)
+@skip_if_rocm("ROCm enablement in progress")
+@pytest.mark.parametrize("m,k,n_groups", [(16640, 2048, 8), (131072, 8192, 32)])
+@pytest.mark.parametrize("chunk_width", [64, 128])
+@pytest.mark.parametrize("chunks_per_tb", [1, 4, 8, 16])
+def test_cuda_mx_block_rearrange_2d_M_groups(
+    m: int,
+    k: int,
+    n_groups: int,
+    chunk_width: int,
+    chunks_per_tb: int,
+):
+    device = "cuda"
+    block_size = 32
+    input_data = torch.randn(m, k, device=device)
+    e8m0_scales, _ = to_mx(
+        input_data, elem_dtype=torch.float8_e4m3fn, block_size=block_size
+    )
+    scale_rows, scale_cols = e8m0_scales.shape
+
+    input_group_offsets = generate_jagged_offs(
+        n_groups, m, multiple_of=block_size, device=device
+    )
+
+    # torch reference
+    ref_out_scales, _ = torch_to_blocked_2d_M_groups(
+        e8m0_scales, input_group_offsets, block_size=block_size
+    )
+
+    # cuda kernel
+    cuda_out_scales = mx_block_rearrange_2d_M_groups_cuda(
+        e8m0_scales,
+        input_group_offsets,
+        chunk_width=chunk_width,
+        chunks_per_tb=chunks_per_tb,
+    )
+    assert torch.allclose(ref_out_scales, cuda_out_scales, atol=0, rtol=0), (
+        "blocked scales not equal"
+    )
+
+
 @skip_if_rocm("ROCm enablement in progress")
 @pytest.mark.parametrize("e,n,k", [(1, 8192, 5120), (2, 8192, 5120), (8, 5120, 8192)])
 def test_mxfp8_per_group_blocked_scales_3d(
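
The test above leans on generate_jagged_offs to build the per-group end offsets along M. As a purely hypothetical sketch of what those offsets look like (the helper name make_jagged_offs and its sampling scheme below are illustrative, not the torchao implementation):

import torch

def make_jagged_offs(n_groups: int, total_m: int, multiple_of: int) -> torch.Tensor:
    # Cumulative end-offsets of n_groups random-sized groups that partition
    # total_m rows, with every boundary aligned to `multiple_of`.
    assert total_m % multiple_of == 0
    n_blocks = total_m // multiple_of
    # n_groups - 1 distinct interior cut points, in units of `multiple_of`.
    cuts = torch.randperm(n_blocks - 1)[: n_groups - 1].sort().values + 1
    return torch.cat([cuts * multiple_of, torch.tensor([total_m])])

# e.g. make_jagged_offs(4, 256, 32) might return tensor([64, 96, 192, 256])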
