Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from benchmarks.utils import benchmark_cuda_function_in_microseconds
from torchao.prototype.moe_training.kernels.mxfp8 import (
mx_block_rearrange_2d_M_groups_cuda,
torch_to_blocked_2d_M_groups,
triton_mx_block_rearrange_2d_M_groups,
)
Expand All @@ -30,14 +31,18 @@
class ExperimentConfig:
input_shape: tuple[int]
num_groups: int
chunk_width: int
chunks_per_tb: int


@dataclass(frozen=True)
class ExperimentResult:
torch_time_us: float
triton_time_us: float
cuda_time_us: float
torch_mem_bw_gbps: float
triton_mem_bw_gbps: float
cuda_mem_bw_gbps: float


@dataclass(frozen=True)
Expand All @@ -47,29 +52,39 @@ class Experiment:


def get_configs() -> List[ExperimentConfig]:
# Llama4 shapes. Input activations are scaled along K dim.
block_size = 32
input_shapes = [
(16640, 5120 // block_size),
(16384, 5120 // block_size),
(131072, 5120 // block_size),
(131072, 2048 // block_size),
(131072, 7168 // block_size),
]
num_groups = [16]
num_groups = [8]
chunk_width_list = [64]
chunks_per_tb_list = [1, 4, 8]

configs = []
for shape, groups in itertools.product(
for shape, groups, chunk_width, chunks_per_tb in itertools.product(
input_shapes,
num_groups,
chunk_width_list,
chunks_per_tb_list,
):
configs.append(
ExperimentConfig(
input_shape=shape,
num_groups=groups,
chunk_width=chunk_width,
chunks_per_tb=chunks_per_tb,
)
)
return configs


def run_experiment(config: ExperimentConfig) -> ExperimentResult:
input_shape, num_groups = config.input_shape, config.num_groups
chunk_width, chunks_per_tb = config.chunk_width, config.chunks_per_tb

input_tensor = torch.randint(
low=0,
high=256,
Expand Down Expand Up @@ -107,6 +122,21 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
input_group_offsets,
)

# bench CUDA pipelined kernel with configured chunk_width and chunks_per_tb
_ = mx_block_rearrange_2d_M_groups_cuda(
input_tensor.view(torch.uint8),
input_group_offsets.to(torch.int32),
chunk_width,
chunks_per_tb,
)
cuda_time_us = benchmark_cuda_function_in_microseconds(
mx_block_rearrange_2d_M_groups_cuda,
input_tensor.view(torch.uint8),
input_group_offsets.to(torch.int32),
chunk_width,
chunks_per_tb,
)

# mem bw calculations
bytes_per_input_el = torch.finfo(torch.float8_e8m0fnu).bits / 8
bytes_per_output_el = torch.finfo(torch.float8_e4m3fn).bits / 8
Expand All @@ -116,23 +146,28 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:

torch_mem_bw_gbps = ((read_bytes + write_bytes) / 1e9) / (torch_time_us / 1e6)
triton_mem_bw_gbps = ((read_bytes + write_bytes) / 1e9) / (triton_time_us / 1e6)
cuda_mem_bw_gbps = ((read_bytes + write_bytes) / 1e9) / (cuda_time_us / 1e6)

return ExperimentResult(
torch_time_us=torch_time_us,
triton_time_us=triton_time_us,
cuda_time_us=cuda_time_us,
torch_mem_bw_gbps=torch_mem_bw_gbps,
triton_mem_bw_gbps=triton_mem_bw_gbps,
cuda_mem_bw_gbps=cuda_mem_bw_gbps,
)


def print_results(experiments: List[Experiment]):
headers = [
"input_shape",
"chunk_width",
"chunks_per_tb",
"torch_time_us",
"triton_time_us",
"torch_mem_bw_gbps",
"triton_mem_bw_gbps",
"cuda_time_us",
"triton_speedup",
"cuda_speedup",
]
rows = []
for experiment in experiments:
Expand All @@ -142,11 +177,13 @@ def print_results(experiments: List[Experiment]):
rows.append(
[
input_shape,
experiment.result.torch_time_us,
experiment.result.triton_time_us,
round(experiment.result.torch_mem_bw_gbps, 3),
round(experiment.result.triton_mem_bw_gbps, 3),
experiment.config.chunk_width,
experiment.config.chunks_per_tb,
f"{experiment.result.torch_time_us:.2f}",
f"{experiment.result.triton_time_us:.2f}",
f"{experiment.result.cuda_time_us:.2f}",
f"{experiment.result.torch_time_us / experiment.result.triton_time_us:.2f}x",
f"{experiment.result.torch_time_us / experiment.result.cuda_time_us:.2f}x",
]
)
print(tabulate(rows, headers=headers))
Expand Down
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,7 @@ def get_extensions():
mxfp8_sources = [
os.path.join(mxfp8_extension_dir, "mxfp8_extension.cpp"),
os.path.join(mxfp8_extension_dir, "mxfp8_cuda.cu"),
os.path.join(mxfp8_extension_dir, "mx_block_rearrange_2d_M_groups.cu"),
]

# Only add the extension if the source files exist AND we are building for sm100
Expand All @@ -722,6 +723,7 @@ def get_extensions():
include_dirs=[
mxfp8_extension_dir, # For mxfp8_quantize.cuh, mxfp8_extension.cpp, and mxfp8_cuda.cu
],
libraries=["cuda"],
extra_compile_args={
"cxx": [
f"-DPy_LIMITED_API={min_supported_cpython_hexcode}",
Expand Down
45 changes: 45 additions & 0 deletions test/prototype/moe_training/test_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
triton_fp8_per_group_rowwise_scales,
)
from torchao.prototype.moe_training.kernels.mxfp8 import (
mx_block_rearrange_2d_M_groups_cuda,
mxfp8_quantize_cuda_3d,
torch_to_blocked_2d_K_groups,
torch_to_blocked_2d_M_groups,
Expand Down Expand Up @@ -244,6 +245,50 @@ def test_triton_mx_block_rearrange_2d_M_groups(
)


@pytest.mark.skipif(
not is_sm_at_least_100(),
reason="MXFP8 requires CUDA capability 10.0 or greater",
)
@skip_if_rocm("ROCm enablement in progress")
@pytest.mark.parametrize("m,k,n_groups", [(16640, 2048, 8), (131072, 8192, 32)])
@pytest.mark.parametrize("chunk_width", [64, 128])
@pytest.mark.parametrize("chunks_per_tb", [1, 4, 8, 16])
def test_cuda_mx_block_rearrange_2d_M_groups(
m: int,
k: int,
n_groups: int,
chunk_width: int,
chunks_per_tb: int,
):
device = "cuda"
block_size = 32
input_data = torch.randn(m, k, device=device)
e8m0_scales, _ = to_mx(
input_data, elem_dtype=torch.float8_e4m3fn, block_size=block_size
)
scale_rows, scale_cols = e8m0_scales.shape

input_group_offsets = generate_jagged_offs(
n_groups, m, multiple_of=block_size, device=device
)

# torch reference
ref_out_scales, _ = torch_to_blocked_2d_M_groups(
e8m0_scales, input_group_offsets, block_size=block_size
)

# cuda kernel
cuda_out_scales = mx_block_rearrange_2d_M_groups_cuda(
e8m0_scales,
input_group_offsets,
chunk_width=chunk_width,
chunks_per_tb=chunks_per_tb,
)
assert torch.allclose(ref_out_scales, cuda_out_scales, atol=0, rtol=0), (
"blocked scales not equal"
)


@skip_if_rocm("ROCm enablement in progress")
@pytest.mark.parametrize("e,n,k", [(1, 8192, 5120), (2, 8192, 5120), (8, 5120, 8192)])
def test_mxfp8_per_group_blocked_scales_3d(
Expand Down
8 changes: 7 additions & 1 deletion torchao/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def _parse_version(version_string):
# Current torchao version
(_parse_version("0.16.0.dev"), _parse_version("2.9.1")),
(_parse_version("0.16.0.dev"), _parse_version("2.10.0.dev")),
(_parse_version("0.16.0.dev"), _parse_version("2.11.0.dev")),
]

current_torch_version = _parse_version(torch.__version__)
Expand All @@ -108,10 +109,15 @@ def _parse_version(version_string):
try:
from pathlib import Path

# Load abi3 .so files and cpython .so files
so_files = list(Path(__file__).parent.glob("_C*.so"))
if len(so_files) > 0:
for file in so_files:
torch.ops.load_library(str(file))
print(f"Loading {file}")
try:
torch.ops.load_library(str(file))
except Exception as e:
print(f"Failed to load {file}: {e}")
from . import ops

# The following registers meta kernels for some CPU kernels
Expand Down
Loading
Loading