fix: Update top-k generation not to produce duplicate expert ids #3208
djns99 wants to merge 4 commits into
Conversation
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior in the settings.
📝 Walkthrough
Replaces ad-hoc per-token expert sampling with a new helper, make_random_topk_ids.
Changes: Top-k expert ID generation
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
Code Review
This pull request introduces a new utility function, make_random_topk_ids, designed to generate unique expert indices for Mixture-of-Experts (MoE) layers by using multinomial sampling without replacement. This function replaces existing torch.randint calls in benchmarks, core MoE logic, and integration tests to ensure more realistic and valid test data. However, two critical bugs were identified in the communication tests where the num_tokens argument was incorrectly passed, leading to shape mismatches between the generated expert IDs and other related tensors.
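As a rough sketch of the idea (the actual helper in flashinfer/fused_moe/utils.py may differ in signature and dtype), multinomial sampling without replacement guarantees unique expert IDs per token:

```python
import torch

# Illustrative sketch only; the real make_random_topk_ids in
# flashinfer/fused_moe/utils.py may use a different signature.
def make_random_topk_ids(num_experts, num_tokens, top_k, device="cpu"):
    # Uniform weights over all experts for every token. Sampling without
    # replacement means each token's top_k expert IDs are unique and lie
    # in [0, num_experts).
    weights = torch.ones(num_tokens, num_experts, device=device)
    return torch.multinomial(weights, top_k, replacement=False).to(torch.int32)
```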
Actionable comments posted: 3
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/fused_moe/core.py`:
- Around lines 926-931: The current call to make_random_topk_ids can fail when
shapes[-1] (top_k) > num_local_experts, because sampling without replacement
needs at least top_k candidates. Update the initializer to ensure the sampling
pool size is at least top_k by computing candidate_experts =
max(num_local_experts, shapes[-1]) and passing that into make_random_topk_ids
(keep the subsequent .view(shapes) and device usage), or alternatively add a
clear assertion that num_local_experts >= shapes[-1] before calling
make_random_topk_ids, to avoid runtime autotuning failures.
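A sketch of the assertion variant, using the illustrative helper defined earlier; the names shapes, num_local_experts, and the call signature are taken from the comment and are assumptions, not verified source:

```python
# Hypothetical stand-ins for the autotuner's values; real names/values differ.
shapes = (256, 8)  # (num_tokens, top_k)
num_local_experts = 32
num_tokens, top_k = shapes[0], shapes[-1]
device = "cpu"

# Guard before sampling: without replacement, we need at least top_k candidates.
assert num_local_experts >= top_k, (
    f"num_local_experts ({num_local_experts}) must be >= top_k ({top_k}) "
    "to sample expert IDs without replacement"
)
topk_ids = make_random_topk_ids(num_local_experts, num_tokens, top_k, device).view(shapes)
```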
In `@flashinfer/fused_moe/cute_dsl/tuner.py`:
- Around lines 263-268: The lambda used for generating top-k ids in tuner.py
hardcodes num_experts=8, which fails when shapes[-1] (top_k) > 8. Update the
lambda passed to make_random_topk_ids to set num_experts to at least the sampled
top_k (e.g., num_experts = max(configured_expert_count, shapes[-1]), or
max(runner.num_experts, shapes[-1]) if a runner/config is available) so that the
sampling population is >= top_k; keep the rest of the call (num_tokens,
top_k=shapes[-1], device) unchanged, and reference make_random_topk_ids and the
lambda that reads shapes to locate the change.
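An illustrative version of the fixed lambda, again using the sketch helper from earlier; num_experts here stands in for the configured expert count:

```python
# Sketch: widen the sampling population so that sampling without
# replacement never requests more samples than there are candidates.
num_experts = 8  # previously hardcoded population size
gen_topk_ids = lambda shapes, device: make_random_topk_ids(
    max(num_experts, shapes[-1]),  # population >= top_k
    shapes[0],                     # num_tokens
    shapes[-1],                    # top_k
    device,
)
```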
In `@flashinfer/fused_moe/utils.py`:
- Around line 253-265: Add an explicit input guard in make_random_topk_ids:
validate that top_k is <= num_experts (and optionally top_k > 0) before calling
torch.multinomial, and raise a clear ValueError (e.g., "top_k must be <=
num_experts; got top_k=X, num_experts=Y") so callers get a descriptive error
instead of the opaque torch.multinomial failure; update the beginning of
make_random_topk_ids to perform this check and return or raise accordingly.
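A minimal sketch of such a guard (the exact message wording is illustrative):

```python
# At the top of make_random_topk_ids, before calling torch.multinomial:
if top_k <= 0 or top_k > num_experts:
    raise ValueError(
        f"top_k must be in (0, num_experts]; got top_k={top_k}, "
        f"num_experts={num_experts}"
    )
```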
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: ff0d0a55-59d6-47d1-9937-065d55397c83
📥 Commits
Reviewing files that changed from the base of the PR and between ed70283 and dfa289209088fdfa99834a0c270c07d90374067a.
📒 Files selected for processing (7)
- benchmarks/bench_trtllm_gen_fused_moe_autotuner.py
- flashinfer/fused_moe/core.py
- flashinfer/fused_moe/cute_dsl/moe_utils.py
- flashinfer/fused_moe/cute_dsl/tuner.py
- flashinfer/fused_moe/utils.py
- tests/autotuner/test_trtllm_fused_moe_autotuner_integration.py
- tests/comm/test_trtllm_moe_alltoall.py
Force-pushed dfa2892 to cf9d6fa
Force-pushed da93e42 to 039a179
♻️ Duplicate comments (1)
flashinfer/fused_moe/utils.py (1)
443-448: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Do not widen num_experts to top_k; it can generate out-of-range expert IDs.

On lines 443-448, rewriting num_experts = top_k means sampled IDs may exceed the original expert space, even though the function promises IDs in [0, num_experts). This can propagate invalid routing IDs downstream.

Suggested fix:

```diff
-    if top_k > num_experts:
-        logger.debug(
-            f"top_k {top_k} is greater than num_experts {num_experts}, using top_k as num_experts"
-        )
-        num_experts = top_k
+    if top_k > num_experts:
+        raise ValueError(
+            f"top_k ({top_k}) must be <= num_experts ({num_experts}) "
+            "for sampling without replacement."
+        )
```

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate.

In `@flashinfer/fused_moe/utils.py` around lines 443-448: the current code mistakenly widens the expert space by setting num_experts = top_k; instead, clamp top_k to num_experts so you never sample IDs outside the original expert range. Keep num_experts unchanged and replace the assignment with top_k = num_experts (and keep/update the logger.debug message accordingly). Make this change where top_k and num_experts are used (the debug block containing logger.debug) so the function still guarantees IDs in [0, num_experts).
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Duplicate comments:
In `@flashinfer/fused_moe/utils.py`:
- Around line 443-448: The current code mistakenly widens the expert space by
setting num_experts = top_k; instead clamp top_k to num_experts so you never
sample IDs outside the original expert range: keep num_experts unchanged and
replace the assignment with top_k = num_experts (and keep/update the
logger.debug message accordingly). Make this change where top_k and num_experts
are used (the debug block containing logger.debug) so the function still
guarantees IDs in [0, num_experts).
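For contrast with the diff above (which raises), a sketch of the clamping behaviour this prompt describes, keeping the [0, num_experts) guarantee by shrinking top_k instead of widening the expert space:

```python
import logging

logger = logging.getLogger(__name__)

# Inside make_random_topk_ids, once top_k and num_experts are known:
if top_k > num_experts:
    logger.debug(
        f"top_k {top_k} is greater than num_experts {num_experts}, "
        "clamping top_k to num_experts"
    )
    top_k = num_experts  # sampled IDs stay within the original expert range
```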
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: ce34e0d0-94b4-4e3e-92ca-80e9fee33e93
📥 Commits
Reviewing files that changed from the base of the PR and between da93e422e9ba1db077897cbc3094e2e5bf380cf4 and 039a179.
📒 Files selected for processing (7)
- benchmarks/bench_trtllm_gen_fused_moe_autotuner.py
- flashinfer/fused_moe/core.py
- flashinfer/fused_moe/cute_dsl/moe_utils.py
- flashinfer/fused_moe/cute_dsl/tuner.py
- flashinfer/fused_moe/utils.py
- tests/autotuner/test_trtllm_fused_moe_autotuner_integration.py
- tests/comm/test_trtllm_moe_alltoall.py
🚧 Files skipped from review as they are similar to previous changes (3)
- tests/autotuner/test_trtllm_fused_moe_autotuner_integration.py
- benchmarks/bench_trtllm_gen_fused_moe_autotuner.py
- tests/comm/test_trtllm_moe_alltoall.py
Force-pushed 039a179 to ebfc0db
/bot run

/bot run
📌 Description
Top-K ids in MoE should not contain duplicate expert IDs. While duplicates don't cause issues in most cases, a number of places use randint() to generate these IDs, which can produce duplicates and lead to unexpected behaviour in downstream code that assumes there are none.
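A small, self-contained illustration of the difference (values in the comments are examples, not actual outputs):

```python
import torch

num_tokens, num_experts, top_k = 4, 8, 4

# torch.randint samples with replacement, so a token's top-k IDs can repeat:
ids = torch.randint(num_experts, (num_tokens, top_k))
# e.g. a row like [4, 4, 3, 0] is possible -- expert 4 appears twice.

# torch.multinomial with replacement=False cannot produce duplicates per row:
weights = torch.ones(num_tokens, num_experts)
ids = torch.multinomial(weights, top_k, replacement=False)
# every row now holds top_k distinct expert IDs in [0, num_experts).
```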
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- Installed pre-commit by running pip install pre-commit (or used your preferred method).
- Installed the hooks with pre-commit install.
- Ran pre-commit run --all-files and fixed any reported issues.

🧪 Tests
- Tests have been added or updated as needed, and all tests pass (unittest, etc.).

Reviewer Notes