fix: Update top-k generation not to produce duplicate expert ids #3208

Open

djns99 wants to merge 4 commits into flashinfer-ai:main from djns99:dastokes/topk_flashinfer_test

Conversation

@djns99 (Contributor) commented Apr 30, 2026

📌 Description

The top-k ids used in MoE should not contain duplicate expert IDs. A number of places use randint() to generate them; randint() samples with replacement, so duplicates can appear. While this doesn't cause issues in most cases, it can lead to unexpected behaviour in downstream code that assumes the ids are unique.
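A minimal repro of the failure mode (illustrative sizes; per the bot walkthrough below, sampling without replacement via torch.multinomial is what the new helper uses):

import torch

num_experts, num_tokens, top_k = 8, 4096, 3

# randint samples with replacement, so a token's row can repeat an expert
# id (e.g. [2, 2, 5]), breaking downstream code that assumes unique ids.
ids = torch.randint(num_experts, (num_tokens, top_k))
print((ids.sort(dim=-1).values.diff(dim=-1) == 0).any())  # True with near-certainty at this size

# Sampling without replacement draws top_k distinct ids for every token.
weights = torch.ones(num_tokens, num_experts)
unique_ids = torch.multinomial(weights, top_k, replacement=False)
assert (unique_ids.sort(dim=-1).values.diff(dim=-1) != 0).all()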

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Refactor
    • Standardized per‑token top‑k expert selection across routing, tuning, benchmarks, docs, and tests by introducing a shared helper that generates consistent, no‑replacement top‑k samples. Updated initializers, examples, test fixtures, and benchmarks to use the unified generator; added module logging for diagnostic visibility.

@coderabbitai (Bot) commented Apr 30, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Replaces ad-hoc per-token expert sampling with a new helper, make_random_topk_ids(...), and updates benchmarks, the tuner, core, utils, doctests, and tests to use it when generating the packed/top-k expert ID tensors used during tuning and testing.

Changes

Top-k expert ID generation

  • New utility (flashinfer/fused_moe/utils.py): Adds make_random_topk_ids(num_experts, num_tokens, top_k, device), which samples per-token, no-replacement top-k expert IDs via torch.multinomial, returns a [num_tokens, top_k] int32 tensor, handles zero-size cases, and logs when top_k > num_experts. A sketch of the helper follows this list.
  • Data shape / sizing (flashinfer/fused_moe/core.py, flashinfer/fused_moe/cute_dsl/tuner.py): Compute num_tokens with math.prod(shapes[:-1]) to drive the utility; add the math and make_random_topk_ids imports.
  • Core initialization (flashinfer/fused_moe/core.py): Replace torch.randint sampling in _init_packed_topk_ids with make_random_topk_ids(...).view(shapes), preserving the packing/shift logic.
  • Auto-tuner wiring (flashinfer/fused_moe/cute_dsl/tuner.py, benchmarks/bench_trtllm_gen_fused_moe_autotuner.py): Replace per-token torch.randint usage with make_random_topk_ids(...) for the dummy top-k tensors used by the autotuner and benchmark packing.
  • Docs / examples (flashinfer/fused_moe/cute_dsl/moe_utils.py): Update the docstring example in moe_sort to import and use make_random_topk_ids(..., device) instead of torch.randint.
  • Tests (tests/autotuner/*, tests/comm/test_trtllm_moe_alltoall.py): Replace torch.randint/torch.randperm-based generation of topk_ids/token_selected_experts with make_random_topk_ids(...), including top_k=1 single-expert cases.
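For reference, a minimal sketch of the helper, reconstructed from the walkthrough above and the snippet quoted in the second review round below; the authoritative implementation is in flashinfer/fused_moe/utils.py, and details such as the exact zero-size handling are assumptions here:

import logging

import torch

logger = logging.getLogger(__name__)


def make_random_topk_ids(
    num_experts: int, num_tokens: int, top_k: int, device: torch.device
) -> torch.Tensor:
    # Short-circuit zero-size inputs, which torch.multinomial rejects.
    if num_tokens == 0 or top_k == 0:
        return torch.empty(num_tokens, top_k, dtype=torch.int32, device=device)
    # This branch matches the snippet quoted in the review thread below;
    # the reviewers go on to debate whether to widen, clamp, or raise.
    if top_k > num_experts:
        logger.debug(
            f"top_k {top_k} is greater than num_experts {num_experts}, "
            "using top_k as num_experts"
        )
        num_experts = top_k
    # Uniform weights plus sampling without replacement give each token
    # top_k distinct expert ids in [0, num_experts).
    weights = torch.ones(num_tokens, num_experts, device=device)
    ids = torch.multinomial(weights, top_k, replacement=False)
    return ids.to(torch.int32)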

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs

Suggested reviewers

  • bkryu
  • yzh119
  • cyx-6
  • aleozlx
  • samuellees
  • yongwww
  • sricketts
  • nv-yunzheq

Poem

🐰 I hop through tensors, quick and bright,
I pick top-k experts by helper light.
No more randint scatter in my lair,
A tidy sampler sorts them with care.
Carrots, bits, and code — hop, everywhere!

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: Docstring coverage is 58.33%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (4 passed)

  • Title check ✅ Passed: The title clearly and concisely describes the main change (preventing duplicate expert IDs in MOE top-k generation), which aligns with the primary objective of the changeset.
  • Linked Issues check ✅ Passed: Check skipped because no linked issues were found for this pull request.
  • Out of Scope Changes check ✅ Passed: Check skipped because no linked issues were found for this pull request.
  • Description check ✅ Passed: The PR description adequately explains the problem (duplicate expert IDs in top-k MOE selections) and mentions that pre-commit checks were run and tests were updated; however, it lacks specific detail about which files were changed and what the solution entails.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


@gemini-code-assist (Bot) left a comment

Code Review

This pull request introduces a new utility function, make_random_topk_ids, designed to generate unique expert indices for Mixture-of-Experts (MoE) layers by using multinomial sampling without replacement. This function replaces existing torch.randint calls in benchmarks, core MoE logic, and integration tests to ensure more realistic and valid test data. However, two critical bugs were identified in the communication tests where the num_tokens argument was incorrectly passed, leading to shape mismatches between the generated expert IDs and other related tensors.

Comment thread tests/comm/test_trtllm_moe_alltoall.py Outdated
Comment thread tests/comm/test_trtllm_moe_alltoall.py Outdated
@coderabbitai (Bot) left a comment

Actionable comments posted: 3

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/fused_moe/core.py`:
- Around line 926-931: The current call to make_random_topk_ids can fail when
shapes[-1] (top_k) > num_local_experts because sampling without replacement
needs at least top_k candidates; update the initializer to ensure the sampling
pool size is at least top_k by computing a candidate_experts =
max(num_local_experts, shapes[-1]) and pass that into make_random_topk_ids (keep
the subsequent .view(shapes) and device usage), or alternatively add a clear
assertion that num_local_experts >= shapes[-1] before calling
make_random_topk_ids to avoid runtime autotuning failures.

In `@flashinfer/fused_moe/cute_dsl/tuner.py`:
- Around line 263-268: The lambda used for generating top-k ids in tuner.py
hardcodes num_experts=8 which fails when shapes[-1] (top_k) > 8; update the
lambda passed to make_random_topk_ids to set num_experts to at least the sampled
top_k (e.g., num_experts = max(configured_expert_count, shapes[-1]) or
max(runner.num_experts, shapes[-1]) if a runner/config is available) so the
sampling population >= top_k; keep the rest of the call (num_tokens,
top_k=shapes[-1], device) unchanged and reference make_random_topk_ids and the
lambda that reads shapes to locate the change.

In `@flashinfer/fused_moe/utils.py`:
- Around line 253-265: Add an explicit input guard in make_random_topk_ids:
validate that top_k is <= num_experts (and optionally top_k > 0) before calling
torch.multinomial, and raise a clear ValueError (e.g., "top_k must be <=
num_experts; got top_k=X, num_experts=Y") so callers get a descriptive error
instead of the opaque torch.multinomial failure; update the beginning of
make_random_topk_ids to perform this check and return or raise accordingly.
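A minimal sketch of the guard the utils.py comment asks for (hypothetical; the helper body is elided, and the second review round below ends up suggesting essentially the same ValueError):

def make_random_topk_ids(num_experts, num_tokens, top_k, device):
    # Fail fast with a descriptive message instead of surfacing an opaque
    # torch.multinomial error from no-replacement sampling.
    if top_k <= 0:
        raise ValueError(f"top_k must be positive; got top_k={top_k}")
    if top_k > num_experts:
        raise ValueError(
            f"top_k must be <= num_experts; "
            f"got top_k={top_k}, num_experts={num_experts}"
        )
    ...  # sampling logic unchanged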

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: ff0d0a55-59d6-47d1-9937-065d55397c83

📥 Commits

Reviewing files that changed from the base of the PR and between ed70283 and dfa289209088fdfa99834a0c270c07d90374067a.

📒 Files selected for processing (7)
  • benchmarks/bench_trtllm_gen_fused_moe_autotuner.py
  • flashinfer/fused_moe/core.py
  • flashinfer/fused_moe/cute_dsl/moe_utils.py
  • flashinfer/fused_moe/cute_dsl/tuner.py
  • flashinfer/fused_moe/utils.py
  • tests/autotuner/test_trtllm_fused_moe_autotuner_integration.py
  • tests/comm/test_trtllm_moe_alltoall.py

Comment thread flashinfer/fused_moe/core.py
Comment thread flashinfer/fused_moe/cute_dsl/tuner.py Outdated
Comment thread flashinfer/fused_moe/utils.py
@djns99 force-pushed the dastokes/topk_flashinfer_test branch from dfa2892 to cf9d6fa on April 30, 2026 01:51
Comment thread flashinfer/fused_moe/utils.py
Comment thread flashinfer/fused_moe/core.py
Comment thread flashinfer/fused_moe/cute_dsl/tuner.py
@djns99 force-pushed the dastokes/topk_flashinfer_test branch from da93e42 to 039a179 on May 6, 2026 21:55
@coderabbitai (Bot) left a comment

♻️ Duplicate comments (1)
flashinfer/fused_moe/utils.py (1)

443-448: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Do not widen num_experts to top_k; it can generate out-of-range expert IDs.

On Line 443-448, rewriting num_experts = top_k means sampled IDs may exceed the original expert space, even though the function promises IDs in [0, num_experts). This can propagate invalid routing IDs downstream.

Suggested fix
-    if top_k > num_experts:
-        logger.debug(
-            f"top_k {top_k} is greater than num_experts {num_experts}, using top_k as num_experts"
-        )
-        num_experts = top_k
+    if top_k > num_experts:
+        raise ValueError(
+            f"top_k ({top_k}) must be <= num_experts ({num_experts}) "
+            "for sampling without replacement."
+        )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@flashinfer/fused_moe/utils.py` around lines 443 - 448, The current code
mistakenly widens the expert space by setting num_experts = top_k; instead clamp
top_k to num_experts so you never sample IDs outside the original expert range:
keep num_experts unchanged and replace the assignment with top_k = num_experts
(and keep/update the logger.debug message accordingly). Make this change where
top_k and num_experts are used (the debug block containing logger.debug) so the
function still guarantees IDs in [0, num_experts).
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Duplicate comments:
In `@flashinfer/fused_moe/utils.py`:
- Around line 443-448: The current code mistakenly widens the expert space by
setting num_experts = top_k; instead clamp top_k to num_experts so you never
sample IDs outside the original expert range: keep num_experts unchanged and
replace the assignment with top_k = num_experts (and keep/update the
logger.debug message accordingly). Make this change where top_k and num_experts
are used (the debug block containing logger.debug) so the function still
guarantees IDs in [0, num_experts).
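As a hypothetical fragment, the clamp this prompt describes (logger and enclosing function as in the helper sketch earlier; clamping keeps the guarantee that ids stay in [0, num_experts), at the cost of returning fewer than the requested top_k columns):

if top_k > num_experts:
    logger.debug(
        f"top_k {top_k} is greater than num_experts {num_experts}, "
        "clamping top_k to num_experts"
    )
    # Clamp top_k rather than widening the expert pool, so sampled ids
    # never leave the original expert range.
    top_k = num_experts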

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: ce34e0d0-94b4-4e3e-92ca-80e9fee33e93

📥 Commits

Reviewing files that changed from the base of the PR and between da93e422e9ba1db077897cbc3094e2e5bf380cf4 and 039a179.

📒 Files selected for processing (7)
  • benchmarks/bench_trtllm_gen_fused_moe_autotuner.py
  • flashinfer/fused_moe/core.py
  • flashinfer/fused_moe/cute_dsl/moe_utils.py
  • flashinfer/fused_moe/cute_dsl/tuner.py
  • flashinfer/fused_moe/utils.py
  • tests/autotuner/test_trtllm_fused_moe_autotuner_integration.py
  • tests/comm/test_trtllm_moe_alltoall.py
🚧 Files skipped from review as they are similar to previous changes (3)
  • tests/autotuner/test_trtllm_fused_moe_autotuner_integration.py
  • benchmarks/bench_trtllm_gen_fused_moe_autotuner.py
  • tests/comm/test_trtllm_moe_alltoall.py

@qiching (Collaborator) left a comment

LGTM

@djns99 force-pushed the dastokes/topk_flashinfer_test branch from 039a179 to ebfc0db on May 6, 2026 23:25
@djns99 (Contributor, Author) commented May 10, 2026

/bot run

@flashinfer-bot (Collaborator) commented

GitLab MR !656 has been created, and the CI pipeline #50858279 is currently running. I'll report back once the pipeline job completes.

@djns99 (Contributor, Author) commented May 11, 2026

/bot run

@flashinfer-bot (Collaborator) commented

GitLab MR !656 has been created, and the CI pipeline #50964324 is currently running. I'll report back once the pipeline job completes.
