perf(autotuner): replace power-of-2 token buckets with hybrid spacing & fix missing routing_replay_out arg #3115
Conversation
📝 Walkthrough

Replaces power-of-2 token bucketing with a new four-phase hybrid bucketing scheme across MoE and GEMM autotuning callsites; threads an optional `routing_replay_out` argument through the MoE forward path down to the `trtllm_fp8_per_tensor_scale_moe` kernel call.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Client as Client/Caller
    participant Tuner as Autotuner/Tuner
    participant Utils as Bucketing Utils
    participant Runner as MoE Runner
    participant Kernel as trtllm_fp8_per_tensor_scale_moe_op
    Client->>Tuner: request tuning / forward(input with num_tokens)
    Tuner->>Utils: get_hybrid_num_tokens_buckets(max_tokens)
    Tuner->>Utils: map_to_hybrid_bucket(num_tokens, max_tokens)
    Tuner->>Runner: select tactic / provide mapped bucket
    Client->>Runner: forward(..., routing_replay_out=?)
    Runner->>Kernel: call trtllm_fp8_per_tensor_scale_moe_op(..., routing_replay_out)
    Kernel-->>Runner: result
    Runner-->>Client: output
```
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~28 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
Code Review
This pull request replaces the power-of-2 token bucket generation logic with a hybrid approach across several modules to improve autotuning for MoE workloads. The new logic uses four phases with varying spacing, including power-of-2 and linear steps. Additionally, a routing_replay_out parameter is added to the MoE forward functions. A logic inconsistency was identified in get_hybrid_num_tokens_buckets where the generated buckets may not align with the mapping function when min_num_tokens is greater than one, which could lead to autotuner failures.
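For orientation, this is roughly the bucket list the four-phase scheme yields for a cap of 8192 tokens (hand-derived from the phase boundaries discussed in this PR, not produced by running the library):

```python
# Illustrative hybrid bucket list for min_num_tokens=1, max_num_tokens=8192,
# hand-derived from the phase boundaries (_PHASE1_END=256, _PHASE2_END=2048,
# _PHASE3_END=4096) described in this PR.
hybrid_buckets_8192 = [
    1, 2, 4, 8, 16, 32, 64, 128, 256,        # Phase 1: power-of-2 up to 256
    512, 768, 1024, 1280, 1536, 1792, 2048,  # Phase 2: linear, step 256
    2560, 3072, 3584, 4096,                  # Phase 3: linear, step 512
    8192,                                    # Phase 4: power-of-2 beyond 4096
]
```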
```python
m = max(min_num_tokens, 1)
while m <= min(max_num_tokens, _PHASE1_END):
    buckets.append(m)
    m *= 2

# Phase 2: linear step 256 in (_PHASE1_END, _PHASE2_END]
m = _PHASE1_END + _PHASE2_STEP
while m <= min(max_num_tokens, _PHASE2_END):
    buckets.append(m)
    m += _PHASE2_STEP

# Phase 3: linear step 512 in (_PHASE2_END, _PHASE3_END]
m = _PHASE2_END + _PHASE3_STEP
while m <= min(max_num_tokens, _PHASE3_END):
    buckets.append(m)
    m += _PHASE3_STEP

# Phase 4: power-of-2 beyond _PHASE3_END
m = _PHASE3_END * 2
while m <= max_num_tokens:
    buckets.append(m)
    m *= 2

if not buckets or buckets[-1] != max_num_tokens:
    buckets.append(max_num_tokens)

return tuple(sorted(set(buckets)))
```
The implementation of get_hybrid_num_tokens_buckets has a critical inconsistency with map_to_hybrid_bucket when min_num_tokens > 1.
- **Phase 1 mismatch:** If `min_num_tokens` is not a power of 2 (e.g. 10), Phase 1 currently generates buckets starting from that value (e.g. `[10, 20, 40, ...]`). However, `map_to_hybrid_bucket` uses `next_positive_power_of_2(x)` for Phase 1, so an input of size 10 maps to bucket 16. Since 16 is not in the generated list, the autotuner will fail to find a tuned tactic for this size.
- **Phase 2-4 filtering:** The loops for the subsequent phases use fixed starting points (e.g. `_PHASE1_END + _PHASE2_STEP`), which results in buckets smaller than `min_num_tokens` being added to the list when `min_num_tokens` is large.
The robust fix is to always generate the full set of potential buckets starting from 1 (to ensure consistency with the mapping logic) and then filter the final result to keep only those within the [min_num_tokens, max_num_tokens] range.
```python
buckets: List[int] = []
# Phase 1: power-of-2 up to _PHASE1_END
m = 1
while m <= min(max_num_tokens, _PHASE1_END):
    buckets.append(m)
    m *= 2
# Phase 2: linear step 256 in (_PHASE1_END, _PHASE2_END]
m = _PHASE1_END + _PHASE2_STEP
while m <= min(max_num_tokens, _PHASE2_END):
    buckets.append(m)
    m += _PHASE2_STEP
# Phase 3: linear step 512 in (_PHASE2_END, _PHASE3_END]
m = _PHASE2_END + _PHASE3_STEP
while m <= min(max_num_tokens, _PHASE3_END):
    buckets.append(m)
    m += _PHASE3_STEP
# Phase 4: power-of-2 beyond _PHASE3_END
m = _PHASE3_END * 2
while m <= max_num_tokens:
    buckets.append(m)
    m *= 2
if not buckets or buckets[-1] != max_num_tokens:
    buckets.append(max_num_tokens)
return tuple(sorted(set(b for b in buckets if b >= min_num_tokens and b <= max_num_tokens)))
```
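A hand-traced example of how the proposed change restores agreement between generation and mapping (values are illustrative, derived from the snippets above rather than executed against the library):

```python
# min_num_tokens=10, max_num_tokens=8192 (hand-traced, illustrative).

# Current code: Phase 1 starts at min_num_tokens itself...
original_phase1 = [10, 20, 40, 80, 160]        # loop stops once 320 > _PHASE1_END
# ...while the mapper rounds 10 up to a power of 2:
mapped = 16                                    # next_positive_power_of_2(10)
assert mapped not in original_phase1           # autotuner finds no tactic for size 10

# Proposed code: generate Phase 1 from 1, then filter to [min, max]:
fixed_phase1 = [b for b in [1, 2, 4, 8, 16, 32, 64, 128, 256] if b >= 10]
assert mapped in fixed_phase1                  # mapper and bucket list now agree
```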
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/fused_moe/utils.py`:
- Around line 217-223: The docstring in fused_moe/utils.py (around the function
that describes the four phases) contains Unicode multiplication characters "×"
which trigger Ruff; replace those with the ASCII letter "x" (e.g., change "step
×2" to "step x2") so the docstring uses plain ASCII and pre-commit passes;
update all occurrences in that docstring text accordingly.
- Around line 224-253: get_hybrid_num_tokens_buckets is not honoring
min_num_tokens across phases: phase1 starts at min_num_tokens without rounding
up to the next power-of-2, and phases 2/3 always start at fixed boundaries
(e.g., _PHASE1_END+_PHASE2_STEP) which can emit values below min_num_tokens. Fix
by computing phase starts relative to min_num_tokens: for phase1 set m to the
smallest power-of-2 >= min_num_tokens (use bit math or loop) and then multiply
by 2; for phase2 set m to the smallest value >= min_num_tokens and >=
(_PHASE1_END+_PHASE2_STEP) that aligns to the _PHASE2_STEP grid (ceil to next
multiple of _PHASE2_STEP); for phase3 do the same alignment with _PHASE3_STEP
and _PHASE2_END; and for phase4 start at max(min_num_tokens, _PHASE3_END*2) then
multiply by 2; ensure every appended bucket >= min_num_tokens and <=
max_num_tokens and keep the final sorting/unique logic intact (variables:
get_hybrid_num_tokens_buckets, min_num_tokens, max_num_tokens, _PHASE1_END,
_PHASE2_STEP, _PHASE2_END, _PHASE3_STEP, _PHASE3_END).
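For reference, a minimal sketch of the alignment helpers this prompt calls for (helper names are illustrative, not from the codebase; the constants are the `_PHASE*` values named in the prompt):

```python
def _next_pow2_at_least(n: int) -> int:
    """Smallest power of 2 that is >= n (for n >= 1)."""
    return 1 << (n - 1).bit_length()


def _ceil_to_multiple(n: int, step: int) -> int:
    """Smallest multiple of step that is >= n."""
    return ((n + step - 1) // step) * step


# Phase starts computed relative to min_num_tokens, as the prompt suggests:
#   phase1_start = _next_pow2_at_least(max(min_num_tokens, 1))
#   phase2_start = max(_PHASE1_END + _PHASE2_STEP, _ceil_to_multiple(min_num_tokens, _PHASE2_STEP))
#   phase3_start = max(_PHASE2_END + _PHASE3_STEP, _ceil_to_multiple(min_num_tokens, _PHASE3_STEP))
#   phase4_start = max(_PHASE3_END * 2, min_num_tokens)
```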
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: f83d9649-3a2d-43ac-98e8-a7f2b490a9f6
📥 Commits
Reviewing files that changed from the base of the PR and between 7d0f68e and 5027cbff0ca18416abce41dd16c53a30e4cf9d1e.
📒 Files selected for processing (5)
- flashinfer/fused_moe/core.py
- flashinfer/fused_moe/cute_dsl/tuner.py
- flashinfer/fused_moe/utils.py
- flashinfer/gemm/gemm_base.py
- flashinfer/trtllm_low_latency_gemm.py
```python
This function uses four phases with progressively coarser spacing::

    Phase 1: [min .. 256] — power-of-2 (step ×2)
    Phase 2: (256 .. 2048] — linear step 256
    Phase 3: (2048 .. 4096] — linear step 512
    Phase 4: (4096 .. max] — power-of-2 (step ×2)
"""
```
Replace ambiguous multiplication signs in the docstring.
Ruff flags the Unicode × characters here; use plain x to keep pre-commit clean.
Proposed fix

```diff
-    Phase 1: [min .. 256] — power-of-2 (step ×2)
+    Phase 1: [min .. 256] — power-of-2 (step x2)
     Phase 2: (256 .. 2048] — linear step 256
     Phase 3: (2048 .. 4096] — linear step 512
-    Phase 4: (4096 .. max] — power-of-2 (step ×2)
+    Phase 4: (4096 .. max] — power-of-2 (step x2)
```

📝 Committable suggestion
```python
This function uses four phases with progressively coarser spacing::

    Phase 1: [min .. 256] — power-of-2 (step x2)
    Phase 2: (256 .. 2048] — linear step 256
    Phase 3: (2048 .. 4096] — linear step 512
    Phase 4: (4096 .. max] — power-of-2 (step x2)
"""
```
🧰 Tools
🪛 Ruff (0.15.10)
[warning] 219-219: Docstring contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?
(RUF002)
[warning] 222-222: Docstring contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?
(RUF002)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/fused_moe/utils.py` around lines 217 - 223, The docstring in
fused_moe/utils.py (around the function that describes the four phases) contains
Unicode multiplication characters "×" which trigger Ruff; replace those with the
ASCII letter "x" (e.g., change "step ×2" to "step x2") so the docstring uses
plain ASCII and pre-commit passes; update all occurrences in that docstring text
accordingly.
/bot run
samuellees
left a comment
There was a problem hiding this comment.
LGTM, waiting for the CI pass
/bot run

/bot run
[FAILED] Pipeline #49156002: 1/20 passed
Hi @StudyingShao, could you please:

Thx!
Signed-off-by: Jiang Shao <91270701+StudyingShao@users.noreply.github.com>
Head branch was pushed to by a user without write access
ee52bd9 to d2de788 (compare)
Actionable comments posted: 1
♻️ Duplicate comments (1)
flashinfer/fused_moe/utils.py (1)
236-239: ⚠️ Potential issue | 🟡 Minor

Replace Unicode `×` with ASCII `x` in docstring (RUF002). Ruff 0.15.10 still flags lines 236 and 239; swap `×` for `x` to keep pre-commit clean.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/fused_moe/utils.py` around lines 236 - 239, Replace the Unicode multiplication symbol '×' with ASCII 'x' in the docstring block that lists "Phase 1" through "Phase 4" (the lines showing steps like "step ×2") in fused_moe.utils so the text reads "step x2" (and similarly for any other occurrences), commit the change to satisfy the RUF002 warning.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/fused_moe/utils.py`:
- Around line 273-289: The branch in map_to_hybrid_bucket currently returns
next_positive_power_of_2(x) when x <= _PHASE1_END, which can exceed
max_num_tokens; change that branch to clamp the result (return
min(next_positive_power_of_2(x), max_num_tokens)) so the function always honors
the [1, max_num_tokens] contract referenced in the docstring (update the branch
handling x <= _PHASE1_END in map_to_hybrid_bucket to use min(...,
max_num_tokens) and keep other branches unchanged).
---
Duplicate comments:
In `@flashinfer/fused_moe/utils.py`:
- Around line 236-239: Replace the Unicode multiplication symbol '×' with ASCII
'x' in the docstring block that lists "Phase 1" through "Phase 4" (the lines
showing steps like "step ×2") in fused_moe.utils so the text reads "step x2"
(and similarly for any other occurrences), commit the change to satisfy the
RUF002 warning.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: dc7330a7-f65b-476d-98ef-c2db6d3f60ee
📥 Commits
Reviewing files that changed from the base of the PR and between ee52bd937b2dd434e797ae72d3d47ec89b37d387 and d2de788.
📒 Files selected for processing (5)
- flashinfer/fused_moe/core.py
- flashinfer/fused_moe/cute_dsl/tuner.py
- flashinfer/fused_moe/utils.py
- flashinfer/gemm/gemm_base.py
- flashinfer/trtllm_low_latency_gemm.py
🚧 Files skipped from review as they are similar to previous changes (3)
- flashinfer/trtllm_low_latency_gemm.py
- flashinfer/gemm/gemm_base.py
- flashinfer/fused_moe/core.py
```python
def map_to_hybrid_bucket(x: int, max_num_tokens: int) -> int:
    """Map an arbitrary num_tokens to the nearest hybrid bucket (rounding up).

    Mirrors the four-phase spacing of :func:`get_hybrid_num_tokens_buckets`.
    The result is clamped to ``[1, max_num_tokens]``.
    """
    if x <= 0:
        return 1
    if x >= max_num_tokens:
        return max_num_tokens
    if x <= _PHASE1_END:
        return next_positive_power_of_2(x)
    if x <= _PHASE2_END:
        return min(_ceil_to_step(x, _PHASE2_STEP), max_num_tokens)
    if x <= _PHASE3_END:
        return min(_ceil_to_step(x, _PHASE3_STEP), max_num_tokens)
    return min(next_positive_power_of_2(x), max_num_tokens)
```
Edge case: map_to_hybrid_bucket can exceed max_num_tokens when max_num_tokens < 256.
For x in (0, max_num_tokens) with max_num_tokens <= _PHASE1_END, the branch at line 283-284 returns next_positive_power_of_2(x) without the max_num_tokens clamp, which can exceed the stated [1, max_num_tokens] contract. Example: map_to_hybrid_bucket(70, 100) returns 128.
All current callsites in this PR pass 8192, so this is not actively exploited — but the docstring guarantees clamping unconditionally, and the returned value won't exist in get_hybrid_num_tokens_buckets(100)'s output, which could cause silent autotuner profile mismatches if someone adopts the API with a small cap in the future.
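A small hand-worked illustration of the gap (values traced from the quoted code above, not executed against the library):

```python
# Hypothetical small cap: max_num_tokens = 100 (< _PHASE1_END = 256), traced by hand.
buckets_100 = (1, 2, 4, 8, 16, 32, 64, 100)   # what get_hybrid_num_tokens_buckets(100) would emit
mapped = 128                                   # next_positive_power_of_2(70) from the Phase-1 branch
assert mapped not in buckets_100               # 128 > 100: no tuned profile exists for this bucket
```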
🛡️ Proposed fix

```diff
 if x <= _PHASE1_END:
-    return next_positive_power_of_2(x)
+    return min(next_positive_power_of_2(x), max_num_tokens)
 if x <= _PHASE2_END:
     return min(_ceil_to_step(x, _PHASE2_STEP), max_num_tokens)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/fused_moe/utils.py` around lines 273 - 289, The branch in
map_to_hybrid_bucket currently returns next_positive_power_of_2(x) when x <=
_PHASE1_END, which can exceed max_num_tokens; change that branch to clamp the
result (return min(next_positive_power_of_2(x), max_num_tokens)) so the function
always honors the [1, max_num_tokens] contract referenced in the docstring
(update the branch handling x <= _PHASE1_END in map_to_hybrid_bucket to use
min(..., max_num_tokens) and keep other branches unchanged).
/bot run

/bot run

/bot run

/bot run
…nfiguration
Adds six no-GPU pytest cases at
`tests/moe/test_cute_dsl_fused_moe.py::TestAutotunerBucketConfig`
guarding the autotuner bucket-cap fix and locking in the load-bearing
behavioral parity with TRT-LLM's pattern at
`cute_dsl_custom_ops.py:2390-2391` and `2700-2703`.
Three "no hardcoded cap" regression guards (the load-bearing
property of the fix):
1. `test_gen_tuning_buckets_is_callable_not_static_tuple` — pins
`gen_tuning_buckets` on the runner's `tuning_config` to be a bare
callable, not a pre-computed tuple.
2. `test_gen_tuning_buckets_no_hardcoded_8192_cap` — verifies that
calling the configured `gen_tuning_buckets` with input dims 8192,
16384, and 32768 produces bucket sets whose maximum reflects the
input value.
3. `test_map_to_tuning_buckets_above_8192_not_capped` — verifies
that `map_to_tuning_buckets(x)` for x ∈ {16384, 32768, 65536}
doesn't cap at 8192. Ensures we use `map_to_hybrid_bucket_uncapped`
instead of `lambda x: map_to_hybrid_bucket(x, 8192)`.
Three TRT-LLM-parity regression guards (lock in the
behavioral-equivalence-where-achievable):
4. `test_map_to_tuning_buckets_phase1_matches_trtllm_at_powers_of_2` —
pins fi/trt-llm parity at power-of-2 inputs ≤ 256 (hybrid Phase 1,
where pure power-of-2 spacing is preserved). At these inputs,
fi's `map_to_tuning_buckets(x)` must equal x and equal
`last_positive_power_of_2(x)` (TRT-LLM's pattern).
5. `test_map_to_tuning_buckets_is_monotonic` — pins monotonic
non-decreasing behavior across hybrid Phases 1-4. TRT-LLM's
`last_positive_power_of_2` and fi's `map_to_hybrid_bucket_uncapped`
both satisfy this; catches a regression that would introduce
non-monotonic mapping.
6. `test_gen_tuning_buckets_covers_trtllm_power_of_2_points` — pins
that fi's hybrid bucket set is a SUPERSET of TRT-LLM's power-of-2
bucket set at every max_n tested. The hybrid scheme intentionally
adds intermediate linear-step buckets in Phase 2/3 (per PR flashinfer-ai#3115's
perf rationale) but must preserve the coarse-grained power-of-2
coverage TRT-LLM has.
These six tests together enforce: (a) no hardcoded cap, (b) callable
form, (c) TRT-LLM-equivalence at power-of-2 probe points, (d)
monotonicity, (e) coarse-grained coverage parity with TRT-LLM. The
hybrid-vs-power-of-2 deviation in Phase 2/3/4 is intentional and
documented (PR flashinfer-ai#3115); the tests don't enforce parity in those phases
because that would regress fi's deliberate perf optimization.
All tests are pure-Python and run without a GPU. They construct a
`CuteDslFusedMoENvfp4Runner` with a no-op `forward_impl` to inspect
its `tuning_config`; no GPU, no CuteDSL kernel binaries, no autotune
side effects.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
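For context, a minimal sketch of what guard (5) above might look like (the import path and single-argument signature of `map_to_hybrid_bucket_uncapped` are assumed from this commit message; the real test lives in `tests/moe/test_cute_dsl_fused_moe.py`):

```python
# Illustrative sketch only; not the actual test file.
from flashinfer.fused_moe.utils import map_to_hybrid_bucket_uncapped  # assumed import path


def test_map_to_tuning_buckets_is_monotonic():
    # Guard 5: mapping must be non-decreasing across hybrid Phases 1-4.
    prev = 0
    for x in range(1, 8193):
        cur = map_to_hybrid_bucket_uncapped(x)  # assumed single-arg signature
        assert cur >= prev
        prev = cur
```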
Address gemini-code-assist review on PR flashinfer-ai#3216: the test was importing `get_last_power_of_2_num_tokens_buckets` from `flashinfer.fused_moe.utils`, but PR flashinfer-ai#3115 (merged 2026-04-24) removed that function in favor of the hybrid bucket scheme. The import would have caused an ImportError when the test was collected. Replace the call with an equivalent inline construction that mirrors TRT-LLM's `get_last_power_of_2_num_tokens_buckets` (in `tensorrt_llm/_torch/utils.py:291`): powers of 2 from 1 up to `last_positive_power_of_2(max_n)`. `last_positive_power_of_2` is still available in `flashinfer.fused_moe.utils`. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…me input (#3216)

## 📌 Description

The autotuner's `DynamicTensorSpec` in `flashinfer/fused_moe/cute_dsl/tuner.py` declared `gen_tuning_buckets` as the pre-computed tuple `get_hybrid_num_tokens_buckets(8192)` and `map_to_tuning_buckets` as `lambda x: map_to_hybrid_bucket(x, 8192)`. The hardcoded 8192 cap silently clamped any runtime workload larger than that to the 8192-bucket's cached tactic — at DeepSeek-V3 prefill (N=16384) fi profiled at half the per-expert workload and used a tactic optimized for the wrong shape.

This PR replaces the pre-computed tuple with the bare callable form (`get_hybrid_num_tokens_buckets`) and switches the mapper to the uncapped variant `map_to_hybrid_bucket_uncapped` (added alongside the hybrid-bucket scheme for exactly this case). The autotuner now invokes them with the actual input dim at autotune time, matching TRT-LLM's pattern at `cute_dsl_custom_ops.py:2390-2391` and flashinfer's own pattern at `gemm/gemm_base.py:_FP8_GEMM_SM100_TUNING_CONFIG`.

## 🔍 Related Issues

#3171 #3198 #3115

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

## Summary by CodeRabbit

* **Bug Fixes**
  * MoE autotuner now uses uncapped dynamic hybrid bucket mapping instead of a fixed-bounded set, improving adaptation to varying input token sizes.
* **Tests**
  * Added offline tests validating autotuner bucket configuration: dynamic bucket generation, responsiveness to input size, monotonic mapping behavior, large-input scaling, and alignment with expected power-of-2 bucket values.

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
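In code, the mapper rewiring described above amounts to roughly the following (a sketch; `map_to_hybrid_bucket_uncapped` is assumed to take only the token count, and the expected output of the new mapper is an assumption rather than observed behavior):

```python
from flashinfer.fused_moe.utils import map_to_hybrid_bucket, map_to_hybrid_bucket_uncapped


def old_map(x: int) -> int:
    # Previous wiring: hardcoded 8192 cap baked into the tuning config.
    return map_to_hybrid_bucket(x, 8192)


new_map = map_to_hybrid_bucket_uncapped  # new wiring: follows the actual input dim

print(old_map(16384))  # 8192 (clamped), so the cached tactic targets half the workload
print(new_map(16384))  # expected to track the real size (a bucket >= 16384)
```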
📌 Description
This PR includes two improvements:
1. **perf(autotuner): Replace power-of-2 token buckets with hybrid spacing** — Pure power-of-2 spacing creates huge gaps at large values (e.g. a jump from 1024 to 2048), forcing the autotuner to pick a kernel optimised for a very different workload size. The new hybrid scheme uses four phases with progressively coarser spacing:
   - `[min .. 256]` — power-of-2 (step ×2)
   - `(256 .. 2048]` — linear step 256
   - `(2048 .. 4096]` — linear step 512
   - `(4096 .. max]` — power-of-2 (step ×2)

   All callsites in MoE, GEMM, and low-latency GEMM autotuners are updated to use the new `get_hybrid_num_tokens_buckets` / `map_to_hybrid_bucket` API.

2. **fix: Pass missing `routing_replay_out` arg to `trtllm_fp8_per_tensor_scale_moe`** — Two call sites in `fused_moe/core.py` were missing the `routing_replay_out` argument, causing it to be silently dropped.

🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- Tests have been added or updated as needed.
- All tests are passing (`unittest`, etc.).

Reviewer Notes
Changed files:
- `flashinfer/fused_moe/utils.py` — Core implementation: new `get_hybrid_num_tokens_buckets`, `map_to_hybrid_bucket`, `map_to_hybrid_bucket_uncapped`; removed old `get_last_power_of_2_num_tokens_buckets`
- `flashinfer/fused_moe/core.py` — Updated all MoE autotuner callsites + added missing `routing_replay_out` arg
- `flashinfer/fused_moe/cute_dsl/tuner.py` — Updated CuTe DSL FP4 MoE tuner callsite
- `flashinfer/gemm/gemm_base.py` — Updated GEMM (FP8, BF16, FP4, MXFP8, TGV) autotuner configs
- `flashinfer/trtllm_low_latency_gemm.py` — Updated low-latency GEMM autotuner config

Summary by CodeRabbit
Improvements
New Features