
Commit 4381afc

leejnaunv, yunzheq, and claude authored
bench(moe_deepseek): fix moe benchmark (supersedes #2886) (#3292)
## 📌 Description

This refreshes `benchmarks/bench_moe_deepseek.py` so it runs cleanly on current main. It rebases Yunzhe Qiu's bench rewrite (`5677a080` from #2886, which restructures the bench so autotuning runs inside the `bench_gpu_time` measurement region) onto post-#3252 main, plus a small follow-up that fixes a stale `RoutingMethodType` import.

Two commits from the original #2886 are intentionally dropped because their fixes have since landed independently:

- `c0b80b64`'s `num_tokens <= max_num_tokens` prealloc guard is now subsumed by #3252 — the `use_prealloc` predicate in `cute_dsl/fused_moe.py` already includes that check.
- `f3beb602`'s `_force_autotune_off()` bench-side workaround for CUPTI measurement pollution is no longer needed: #3126 moved the cache lookup ahead of `_prepare_input_tensors` synthesis in the autotuner's tuning-mode loop, eliminating the pollution at the source.

The only remaining mismatch with current main was `RoutingMethodType`, which moved from `flashinfer.fused_moe.core` to `flashinfer.tllm_enums` (and is re-exported via `flashinfer.fused_moe`) — fixed in the second commit here.

Verified on B200 inside `nvcr.io/nvidia/tensorrt-llm/release:1.3.0rc14`: DeepSeek-V3 at bs=128, ep=8 measures CuteDSL = 0.147 ms / TRTLLM = 0.144 ms — in the clean band that matches prior post-pollution-fix measurements (~0.157 / ~0.142). An 18-cell matrix (N = 1, 8, 128, 512, 2048, 16384 × EP = 1, 8, 16) and an 8-cell gen-phase decode sweep also ran without errors.

Closes #2886.

## 🔍 Related Issues

#2886, #3126, #3252

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Summary by CodeRabbit

* **Refactor**
  * Improved MoE throughput benchmarking methodology with enhanced pre-warm invocation and synchronization for accurate timing capture
  * Refactored autotuning strategy to occur inline during benchmark warmup phase
  * Reorganized benchmark output display for clearer result presentation

---------

Co-authored-by: Yunzhe Qiu <yunzheq@nvidia.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
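For orientation before the diff: the core restructure is that tuning now happens inside the measured region rather than in a separate pass. Below is a minimal sketch of that pattern, assuming only that `flashinfer.autotuner.autotune` is the boolean context manager used in the diff; `bench_one` is a hypothetical stand-in for the bench's `_benchmark_single`:

```python
import contextlib

from flashinfer.autotuner import autotune


def bench_all(token_counts, bench_one, do_autotune=True):
    # bench_one(n) is a hypothetical callable: it builds inputs for n tokens,
    # runs the backend under bench_gpu_time, and returns a timing record.
    ctx = autotune(True) if do_autotune else contextlib.nullcontext()
    results = []
    with ctx:
        # Inside autotune(True), each backend's first invocation profiles all
        # tactics during the dry-run warmup; later calls reuse the cached best
        # tactic, so the timed iterations measure steady state.
        for n in token_counts:
            results.append(bench_one(n))
    return results
```

Because tuning runs inside the very calls that get timed, the autotuner necessarily keys on the same API path, EP config, and shapes that the measurement uses.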
1 parent 4dba29f commit 4381afc

1 file changed: benchmarks/bench_moe_deepseek.py (+53 additions, −242 deletions)
```diff
@@ -342,6 +342,15 @@ def run(x, x_sf, router_logits, routing_bias, topk_values, topk_indices):
         "topk_indices": ti,
     }
 
+    # Pre-warm: run once on the default stream so that autotuning (which
+    # allocates tensors and profiles multiple tactics) finishes before
+    # bench_gpu_time moves execution to a side stream for CUDA-graph
+    # capture. Autotuning on a non-default stream triggers illegal-
+    # memory-access errors in the CuteDSL persistent-tile-scheduler
+    # kernels.
+    run(**input_kwargs)
+    torch.cuda.synchronize()
+
     times = bench_gpu_time(
         run,
         dry_run_iters=warmup,
```
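The stream constraint in the new comment generalizes beyond this bench. Here is a minimal sketch of the same pre-warm-then-capture ordering in plain PyTorch (not FlashInfer's `bench_gpu_time`, whose internals are not part of this diff; `fn` is a hypothetical bench closure that reuses static input tensors, as CUDA-graph capture requires):

```python
import torch


def bench_with_cuda_graph(fn, iters=100):
    # Pre-warm on the default stream: lazy initialization and autotuning must
    # finish here, before capture moves execution to a non-default stream.
    fn()
    torch.cuda.synchronize()

    # torch.cuda.graph captures on an internal side stream, so anything that
    # only works on the default stream must already have happened above.
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        fn()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        g.replay()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # milliseconds per iteration
```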
```diff
@@ -449,6 +458,10 @@ def run(hidden, sf, router_logits, routing_bias, topk_values, topk_indices):
         "topk_indices": ti,
     }
 
+    # warmup and autotune
+    run(**input_kwargs)
+    torch.cuda.synchronize()
+
     times = bench_gpu_time(
         run,
         dry_run_iters=warmup,
```
```diff
@@ -470,9 +483,8 @@ def bench_trtllm(
     use_cuda_graph=True,
     use_cupti=True,
 ):
-    from flashinfer.fused_moe import trtllm_fp4_block_scale_moe
+    from flashinfer.fused_moe import trtllm_fp4_block_scale_moe, RoutingMethodType
     from flashinfer.fused_moe.core import (
-        RoutingMethodType,
         _maybe_get_cached_w3_w1_permute_indices,
         get_w2_permute_indices_with_cache,
     )
```
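Since `RoutingMethodType` is re-exported via `flashinfer.fused_moe` on current main but lived in `flashinfer.fused_moe.core` before the move, an out-of-tree script that must span both versions could use a shim like this (a sketch, not part of this commit):

```python
try:
    # Current main: defined in flashinfer.tllm_enums, re-exported from the
    # flashinfer.fused_moe package.
    from flashinfer.fused_moe import RoutingMethodType
except ImportError:
    # Older releases: defined in flashinfer.fused_moe.core.
    from flashinfer.fused_moe.core import RoutingMethodType
```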
```diff
@@ -575,6 +587,10 @@ def run(routing_logits, routing_bias, hidden_states, hidden_states_scale):
         "hidden_states_scale": hsc,
     }
 
+    # warmup and autotune
+    run(**input_kwargs)
+    torch.cuda.synchronize()
+
     times = bench_gpu_time(
         run,
         dry_run_iters=warmup,
```
```diff
@@ -592,218 +608,6 @@ def run(routing_logits, routing_bias, hidden_states, hidden_states_scale):
 # =============================================================================
 
 
-def run_autotune(inputs, verbose=True):
-    from flashinfer.fused_moe import (
-        fused_topk_deepseek,
-        cutlass_fused_moe,
-        trtllm_fp4_block_scale_moe,
-    )
-    from flashinfer.fused_moe.core import (
-        RoutingMethodType,
-        _maybe_get_cached_w3_w1_permute_indices,
-        get_w2_permute_indices_with_cache,
-    )
-    from flashinfer import cute_dsl_fused_moe_nvfp4
-    from flashinfer.cute_dsl.utils import convert_sf_to_mma_layout
-    from flashinfer.fp4_quantization import fp4_quantize, block_scale_interleave
-    from flashinfer.autotuner import autotune
-
-    if verbose:
-        print("\nRunning autotune warmup for all backends...")
-        print("-" * 80)
-
-    n, sv, dev = inputs["router_logits"].shape[0], 16, "cuda"
-    gs1 = torch.tensor([1.0], device=dev)
-
-    tv = torch.empty(n, CFG.top_k, dtype=torch.float32, device=dev)
-    ti = torch.empty(n, CFG.top_k, dtype=torch.int32, device=dev)
-    fused_topk_deepseek(
-        scores=inputs["router_logits"],
-        bias=inputs["routing_bias"].float(),
-        n_group=CFG.n_group,
-        topk_group=CFG.topk_group,
-        topk=CFG.top_k,
-        routed_scaling_factor=CFG.routed_scaling_factor,
-        topk_values=tv,
-        topk_indices=ti,
-    )
-
-    # -------------------------------------------------------------------------
-    # CuteDSL autotune
-    # -------------------------------------------------------------------------
-    if verbose:
-        print("Autotuning CuteDSL...")
-
-    xf, xs = fp4_quantize(inputs["hidden_bf16"], gs1, sv, False, False)
-    xs = xs.unsqueeze(-1)
-
-    w1i = interleave(inputs["w1_bf16"], 64)
-    w1f = w1i.view(CFG.num_experts * 2 * CFG.intermediate_size, CFG.hidden_size)
-    w1q, w1s = fp4_quantize(w1f, gs1, sv, False, True)
-    w1q = w1q.view(CFG.num_experts, 2 * CFG.intermediate_size, CFG.hidden_size // 2)
-    w1s = convert_sf_to_mma_layout(
-        w1s, 2 * CFG.intermediate_size, CFG.hidden_size, CFG.num_experts, sv
-    )
-
-    w2f = inputs["w2_bf16"].view(
-        CFG.num_experts * CFG.hidden_size, CFG.intermediate_size
-    )
-    w2q, w2s = fp4_quantize(w2f, gs1, sv, False, True)
-    w2q = w2q.view(CFG.num_experts, CFG.hidden_size, CFG.intermediate_size // 2)
-    w2s = convert_sf_to_mma_layout(
-        w2s, CFG.hidden_size, CFG.intermediate_size, CFG.num_experts, sv
-    )
-
-    alpha, fc2sc = (
-        torch.ones(CFG.num_experts, device=dev),
-        torch.tensor([1.0], device=dev),
-    )
-
-    with autotune(True):
-        for _ in range(10):
-            cute_dsl_fused_moe_nvfp4(
-                x=xf,
-                x_sf=xs,
-                token_selected_experts=ti,
-                token_final_scales=tv,
-                w1_weight=w1q,
-                w1_weight_sf=w1s,
-                w1_alpha=alpha,
-                fc2_input_scale=fc2sc,
-                w2_weight=w2q,
-                w2_weight_sf=w2s,
-                w2_alpha=alpha,
-                num_experts=CFG.num_experts,
-                top_k=CFG.top_k,
-                num_local_experts=CFG.num_experts,
-                local_expert_offset=0,
-            )
-    torch.cuda.synchronize()
-
-    # -------------------------------------------------------------------------
-    # CUTLASS autotune
-    # -------------------------------------------------------------------------
-    if verbose:
-        print("Autotuning CUTLASS...")
-
-    a1_gs = torch.tensor(1.0, device=dev, dtype=torch.float32)
-    a2_gs = torch.tensor(1.0, device=dev, dtype=torch.float32)
-    quant_scales = [
-        a1_gs,
-        inputs["w1_sf"].view(torch.int32),
-        1.0 / (a1_gs * inputs["w1_gs"]),
-        a2_gs,
-        inputs["w2_sf"].view(torch.int32),
-        1.0 / (a2_gs * inputs["w2_gs"]),
-    ]
-    hidden_fp4, input_sf = fp4_quantize(inputs["hidden_bf16"], a1_gs, sv, False, True)
-    output_cutlass = torch.empty(n, CFG.hidden_size, dtype=torch.bfloat16, device=dev)
-
-    with autotune(True):
-        for _ in range(10):
-            cutlass_fused_moe(
-                hidden_fp4,
-                ti.to(torch.int),
-                tv,
-                inputs["w1_fp4"].contiguous().view(torch.long),
-                inputs["w2_fp4"].contiguous().view(torch.long),
-                torch.bfloat16,
-                quant_scales=quant_scales,
-                input_sf=input_sf,
-                output=output_cutlass,
-            )
-    torch.cuda.synchronize()
-
-    # -------------------------------------------------------------------------
-    # TRTLLM Gen autotune
-    # -------------------------------------------------------------------------
-    if verbose:
-        print("Autotuning TRTLLM Gen...")
-
-    etm, cache = 128, {}
-    hg = inputs["hidden_gs"]
-    hfp, hsf = fp4_quantize(inputs["hidden_bf16"], hg, sv, False, True)
-    hfp = hfp.view(torch.uint8).reshape(n, CFG.hidden_size // 2)
-    hsc = (
-        hsf.view(torch.float8_e4m3fn)
-        .flatten()[: n * CFG.hidden_size // sv]
-        .reshape(n, CFG.hidden_size // sv)
-    )
-
-    def prep(bf16, gs, M, K):
-        fl, sl = [], []
-        for e in range(CFG.num_experts):
-            q, s = fp4_quantize(bf16[e], gs[e], sv, False, False)
-            fl.append(q.view(torch.uint8).reshape(M, K // 2))
-            sl.append(s.view(torch.float8_e4m3fn).reshape(M, K // sv))
-        return torch.stack(fl), torch.stack(sl)
-
-    w1f_trt, w1s_trt = prep(
-        inputs["w1_bf16"], inputs["w1_gs"], 2 * CFG.intermediate_size, CFG.hidden_size
-    )
-    w2f_trt, w2s_trt = prep(
-        inputs["w2_bf16"], inputs["w2_gs"], CFG.hidden_size, CFG.intermediate_size
-    )
-
-    def shuf(fp4, sf, perm_fn):
-        fsh, ssh = [], []
-        for i in range(CFG.num_experts):
-            p = perm_fn(cache, fp4[i], etm)
-            fsh.append(fp4[i][p.to(dev)].contiguous())
-            ps = perm_fn(cache, sf[i].view(torch.uint8), etm, sv)
-            ssh.append(
-                block_scale_interleave(sf[i].view(torch.uint8)[ps.to(dev)].contiguous())
-            )
-        return torch.stack(fsh), torch.stack(ssh)
-
-    w1f_trt, w1s_trt = shuf(w1f_trt, w1s_trt, _maybe_get_cached_w3_w1_permute_indices)
-    w2f_trt, w2s_trt = shuf(w2f_trt, w2s_trt, get_w2_permute_indices_with_cache)
-    w1s_trt = w1s_trt.view(torch.float8_e4m3fn).reshape(
-        CFG.num_experts, 2 * CFG.intermediate_size, CFG.hidden_size // sv
-    )
-    w2s_trt = w2s_trt.view(torch.float8_e4m3fn).reshape(
-        CFG.num_experts, CFG.hidden_size, CFG.intermediate_size // sv
-    )
-
-    sc = torch.ones(CFG.num_experts, device=dev, dtype=torch.float32)
-
-    with autotune(True):
-        for _ in range(10):
-            trtllm_fp4_block_scale_moe(
-                routing_logits=inputs["router_logits"],
-                routing_bias=inputs["routing_bias"],
-                hidden_states=hfp,
-                hidden_states_scale=hsc,
-                gemm1_weights=w1f_trt,
-                gemm1_weights_scale=w1s_trt,
-                gemm1_bias=None,
-                gemm1_alpha=None,
-                gemm1_beta=None,
-                gemm1_clamp_limit=None,
-                gemm2_weights=w2f_trt,
-                gemm2_weights_scale=w2s_trt,
-                gemm2_bias=None,
-                output1_scale_scalar=sc,
-                output1_scale_gate_scalar=sc,
-                output2_scale_scalar=sc,
-                num_experts=CFG.num_experts,
-                top_k=CFG.top_k,
-                n_group=CFG.n_group,
-                topk_group=CFG.topk_group,
-                intermediate_size=CFG.intermediate_size,
-                local_expert_offset=0,
-                local_num_experts=CFG.num_experts,
-                routed_scaling_factor=CFG.routed_scaling_factor,
-                routing_method_type=RoutingMethodType.DeepSeekV3,
-                do_finalize=True,
-            )
-    torch.cuda.synchronize()
-
-    if verbose:
-        print("-" * 80)
-        print("Autotune complete for all backends.\n")
-
-
 # =============================================================================
 # Main Benchmark
 # =============================================================================
```
```diff
@@ -834,12 +638,23 @@ def run_benchmark(
     """
     Unified benchmark for DeepSeek-V3 MoE backends.
 
+    Autotuning is merged into the benchmark runs: wrapping the benchmark loop
+    in ``autotune(True)`` causes each backend's first invocation to profile
+    all tactics (during ``bench_gpu_time`` dry-run warmup), with subsequent
+    calls using the cached best tactic. This guarantees the autotuner sees
+    the same API (wrapper vs functional), EP config, and weight shapes as the
+    timed runs, avoiding cache-key mismatches.
+
+    All output is buffered and printed after the benchmark (and autotuning)
+    completes, so autotuner log messages do not interleave with the results
+    table.
+
     Args:
         token_counts: List of token counts to benchmark
         warmup: Warmup iterations
         iters: Benchmark iterations
         ep_config: Expert Parallelism config (1, 8, or 16)
-        do_autotune: Whether to run autotune before benchmarking
+        do_autotune: Whether to autotune during benchmarking
        verbose: Print results to stdout
         use_cuda_graph: Whether to use CUDA graph for benchmarking
         use_cupti: Whether to use CUPTI for accurate GPU timing
```
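To picture the cache-key mismatch the new docstring warns about: if tuning and timing disagree on any field of the tuning key, the timed run misses the cache. The key fields below are purely illustrative and hypothetical; FlashInfer's actual autotuner schema is not shown in this diff:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TuneKey:
    # Hypothetical fields for illustration; not FlashInfer's real schema.
    op: str
    num_tokens: int
    num_local_experts: int
    local_expert_offset: int


_best_tactic: dict[TuneKey, int] = {}


def tactic_for(key: TuneKey, default: int = -1) -> int:
    # A separate run_autotune pass tunes under its own shapes and EP config;
    # if those differ from the timed run's key, this lookup misses and the
    # timed run silently falls back to the default tactic. Tuning inside the
    # benchmark loop makes tuning and timing share the same key by design.
    return _best_tactic.get(key, default)
```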
```diff
@@ -849,19 +664,34 @@ def run_benchmark(
     Returns:
         List of BenchResult objects
     """
+    import contextlib
+
+    from flashinfer.autotuner import autotune
+
     # Get EP configuration
     ep_cfg = EP_CONFIGS.get(ep_config, EP_CONFIGS[1])
     num_local = ep_cfg["num_local_experts"]
     local_offset = ep_cfg["local_expert_offset"]
 
-    # Run autotune if requested (BEFORE printing header to avoid interleaved output)
-    if do_autotune:
-        run_autotune(
-            create_inputs(max(token_counts), routing_bias_scale=routing_bias_scale),
-            verbose=verbose,
-        )
+    results = []
+    rows_and_histograms = []
+
+    with autotune(True) if do_autotune else contextlib.nullcontext():
+        for n in token_counts:
+            row, histogram_record = _benchmark_single(
+                n,
+                warmup,
+                iters,
+                num_local,
+                local_offset,
+                use_cuda_graph,
+                use_cupti,
+                use_wrapper=use_wrapper,
+                routing_bias_scale=routing_bias_scale,
+            )
+            results.extend(row)
+            rows_and_histograms.append((row, histogram_record))
 
-    # Print header AFTER autotune completes
     if verbose:
         _print_header(
             ep_config,
```
```diff
@@ -870,27 +700,8 @@ def run_benchmark(
             use_cupti,
             routing_bias_scale,
         )
-
-    # Run benchmarks
-    results = []
-    for n in token_counts:
-        row, histogram_record = _benchmark_single(
-            n,
-            warmup,
-            iters,
-            num_local,
-            local_offset,
-            use_cuda_graph,
-            use_cupti,
-            use_wrapper=use_wrapper,
-            routing_bias_scale=routing_bias_scale,
-        )
-        results.extend(row)
-        if verbose:
+        for row, histogram_record in rows_and_histograms:
             _print_row(row, histogram_record)
-
-    # Print footer
-    if verbose:
         _print_footer(ep_config, num_local)
 
     return results
```
