fix: remove over-strict K%4 assert in get_shuffle_matrix_sf_a_row_indices#3163
fix: remove over-strict K%4 assert in get_shuffle_matrix_sf_a_row_indices#3163jimmyzho wants to merge 2 commits into
Conversation
|
No actionable comments were generated in the recent review. 🎉 ℹ️ Recent review info⚙️ Run configurationConfiguration used: defaults Review profile: CHILL Plan: Pro Run ID: 📒 Files selected for processing (1)
📝 WalkthroughWalkthroughThe assertion enforcing that K is divisible by 4 was removed from ChangesShuffle matrix logic
Estimated code review effort🎯 2 (Simple) | ⏱️ ~10 minutes Suggested labels
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Tip 💬 Introducing Slack Agent: The best way for teams to turn conversations into code.Slack Agent is built on CodeRabbit's deep understanding of your code, so your team can collaborate across the entire SDLC without losing context.
Built for teams:
One agent for your entire SDLC. Right inside Slack. Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
|
/bot run |
There was a problem hiding this comment.
Code Review
This pull request removes the K % 4 alignment assertion in the get_shuffle_matrix_sf_a_row_indices function within flashinfer/utils.py. This change allows for non-aligned K dimensions, such as those found in GPT-OSS-120B MXFP4 weights, as the function only computes row permutation indices over the M dimension and downstream kernels handle the necessary padding. I have no feedback to provide.
…ices The assertion `K % 4 == 0` in `get_shuffle_matrix_sf_a_row_indices` rejects valid inputs that the downstream kernel handles correctly. This function only computes row permutation indices over the M dimension and never uses K. The downstream `block_scale_interleave` CUDA kernel already pads K to a multiple of 4 internally via `_compute_swizzled_layout_sf_size` (round_up(total_column, 4)). Removing the assertion unblocks models with non-aligned K dimensions, e.g. K=90 in GPT-OSS-120B MXFP4 weights (issue flashinfer-ai#2122). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
20cd7be to
3d21b6f
Compare
Summary
Removes the
assert K % 4 == 0guard fromget_shuffle_matrix_sf_a_row_indicesinflashinfer/utils.py. The assertion rejects valid inputs (e.g.K=90for GPT-OSS-120B MXFP4 weights), even though the downstream kernel handles them correctly.Why it's safe to remove:
get_shuffle_matrix_sf_a_row_indicesonly computes row permutation indices over the M dimension. K is unpacked from the shape but never used.csrc/nv_internal/tensorrt_llm/thop/fp4Op.cpp:197):cols_padded = PadUpFn(cols, 4)and passesn_paddedto the kernel.csrc/nv_internal/cpp/kernels/quantization.cu:400-418): loops up tonumColsPadded, writes0for padded slots (T sf = 0; if (colIdx < numCols) sf = SFIn[...];).flashinfer/quantization/fp4_quantization.py:51-54, 297-304):_compute_swizzled_layout_sf_sizeusesround_up(total_column, 4)to size the output buffer to match.Fixes #2122.
Test plan
get_shuffle_matrix_sf_a_row_indicessucceeds forK=90(the GPT-OSS-120B case that triggered the original assertion).K=1,K=4,K=88(previously valid) still return correct row indices.shuffle_matrix_sf_apipeline (row shuffle →block_scale_interleave) produces correctly-sized outputs for non-aligned K.🤖 Generated with Claude Code
Summary by CodeRabbit