Fix trace-bmm-fp8 test: B should be K-major for subword types#3184
Conversation
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. This behavior can be configured in the CodeRabbit settings.

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration: defaults | Review profile: CHILL | Plan: Pro
📥 Commits: reviewing files that changed from the base of the PR and between 0110434a478431168f4c61cf1c29629fa25d5fef and 135f933.
📒 Files selected for processing (1)

📝 Walkthrough: The tests normalize the contiguity of the batched `B` operand to a K-major layout before calling the FlashInfer API.

Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~10 minutes

🚥 Pre-merge checks: ✅ 5 passed.
/bot run
🧹 Nitpick comments (1)

tests/trace/test_reference_correctness.py (1)

2170-2172: Correct K-major preprocessing for `bmm_bf16`. The `transpose(1, 2).contiguous().transpose(1, 2)` idiom produces a `(B, K, N)` view with stride `(N*K, 1, K)` (K-stride = 1), which matches the column-major layout `bmm_bf16` documents for `B`. Logical values are preserved, so passing the original `b` to the reference remains correct.

Optional: a one-line comment would help future readers understand why the seemingly no-op pattern is necessary — i.e., sub-32-bit BMMs require K-major `B`, but only the `cutlass` backend enforces it.

📝 Optional clarifying comment

```diff
 a = torch.randn(B, M, K, dtype=torch.bfloat16, device="cuda")
 b = torch.randn(B, K, N, dtype=torch.bfloat16, device="cuda")
+# bmm_bf16 requires B in K-major (column-major) layout; round-trip through
+# contiguous() to get strides (N*K, 1, K) without changing logical values.
 b_kmaj = b.transpose(1, 2).contiguous().transpose(1, 2)
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `tests/trace/test_reference_correctness.py` around lines 2170-2172: the pre-processing `b_kmaj = b.transpose(1, 2).contiguous().transpose(1, 2)` is a no-op for logical values but is used to get the K-major strides required only by the cutlass backend. Keep the cutlass call as-is (`api = flashinfer.bmm_bf16(a, b_kmaj, backend="cutlass")`), pass the original `b` to the reference path, and add a one-line comment near `b_kmaj` explaining that the transpose/contiguous/transpose only enforces K-major memory layout for cutlass and that logical values are unchanged.
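The stride claim in the comment above is easy to check outside the test. The sketch below uses NumPy as a stand-in for torch (`np.ascontiguousarray` playing the role of `.contiguous()`); it only illustrates the layout effect and is not part of the test suite:

```python
import numpy as np

B, K, N = 2, 4, 8
b = np.arange(B * K * N, dtype=np.float32).reshape(B, K, N)

# NumPy analogue of torch's b.transpose(1, 2).contiguous().transpose(1, 2):
# materialize a contiguous (B, N, K) copy, then view it back as (B, K, N).
b_kmaj = np.ascontiguousarray(b.transpose(0, 2, 1)).transpose(0, 2, 1)

# Logical values are unchanged...
assert (b == b_kmaj).all()

# ...but the strides (in elements) are now (N*K, 1, K): unit stride along K.
elem_strides = tuple(s // b_kmaj.itemsize for s in b_kmaj.strides)
assert elem_strides == (N * K, 1, K)
```

The same reasoning explains why passing the original `b` to the reference implementation stays correct: only the memory layout differs, not the values.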
ℹ️ Review info
⚙️ Run configuration: defaults | Review profile: CHILL | Plan: Pro | Run ID: c60b9edf-6ad1-467f-958f-dcc573d8be88
📥 Commits: reviewing files that changed from the base of the PR and between 5e1318c and 1c1a4bdf5c30db72fbbd17ffa73aa77cfcf1a6bf.
📒 Files selected for processing (1): tests/trace/test_reference_correctness.py
Code Review
This pull request modifies the bmm_bf16 and bmm_fp8 reference correctness tests to utilize K-major layout tensors when calling the FlashInfer API. Feedback suggests also using these K-major tensors in the reference implementation calls to maintain consistency and ensure that both the kernel and the reference are tested against the same memory representation.
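One way to keep the kernel call and the reference consistent is to make the layout requirement explicit. The helper below is hypothetical (it is not a FlashInfer API) and uses NumPy strides to show what a K-major check for a `(B, K, N)` operand amounts to:

```python
import numpy as np

def is_k_major(b: np.ndarray) -> bool:
    # Hypothetical check (not FlashInfer API): for a (B, K, N) operand,
    # "K-major" means the K axis (axis 1) has unit stride in memory.
    return b.strides[1] == b.itemsize

b = np.zeros((2, 4, 8), dtype=np.float32)  # C-contiguous: N has unit stride
b_kmaj = np.ascontiguousarray(b.transpose(0, 2, 1)).transpose(0, 2, 1)

assert not is_k_major(b)   # plain row-major layout fails the check
assert is_k_major(b_kmaj)  # the round-tripped tensor passes
```

A backend that asserts a condition like this up front fails loudly on a wrong layout instead of silently computing wrong results.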
@saltyminty could you approve / merge this PR? #2711 (SageAttn; presumably other CI runs as well) is blocked by this failure. CC @YangXu1990uiuc for visibility. Thanks!
/bot help
FlashInfer CI Bot

Available Commands:

Authorization: Only whitelisted users can trigger CI. Contact a maintainer for access.

How It Works:

Note: Any whitelisted user can trigger CI for any PR, not just their own.
CI looks good (the failures are node allocation timeouts)
Force-pushed from 1c1a4bd to 162eca5
@saltyminty can we skip CI here? Or do we have to wait until the nodes are back?
We can skip internal CI since this change should be safe, but we need the pre-merge checks to pass before the merge button appears.
Force-pushed from 162eca5 to cebc7a3
Force-pushed from cebc7a3 to 0110434
Hi @saltyminty, I couldn't find the spot where the failed test actually errored:
It looks like a timeout / node failure caused all the other jobs to be killed? (The same goes for #2711.)
Signed-off-by: Ruqing Xu <7891482+xrq-phys@users.noreply.github.com>
Force-pushed from 0110434 to 135f933
📌 Description
Closes #3188
Issue: an upstream change has introduced a failing CI test case:
`tests/trace/test_reference_correctness.py::test_bmm_fp8_reference_correctness`

Cause: `flashinfer.bmm_bf16` and `flashinfer.bmm_fp8` (any sub-32-bit dtypes) expect K-major inputs. The `cutlass` backend checks for this, but the default fp8 backend doesn't, causing wrong results.

🔍 Related Issues
Current CI runs
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- Tests have been added or updated as needed (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit