
Fix trace-bmm-fp8 test: B should be K-major for subword types#3184

Merged
saltyminty merged 1 commit into flashinfer-ai:main from xrq-phys:fix/trace-bmm-fp8 on May 1, 2026

Conversation

@xrq-phys
Contributor

@xrq-phys xrq-phys commented Apr 26, 2026

📌 Description

Closes #3188

Issue: An upstream change introduced a failing CI test case: tests/trace/test_reference_correctness.py::test_bmm_fp8_reference_correctness

Cause: flashinfer.bmm_bf16 and flashinfer.bmm_fp8 (all sub-32-bit dtypes) expect the B operand in K-major layout. The cutlass backend checks for this, but the default fp8 backend doesn't, so it silently produces wrong results.
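For illustration, here is what "K-major" means in terms of strides. This is a NumPy sketch standing in for the PyTorch tensors used in the actual test; the shapes are made up and not from the PR:

```python
import numpy as np

# Illustrative shapes (not from the PR): batch=2, K=4, N=3.
# "K-major" means the stride along the K dimension is 1, i.e. each
# (K, N) slice is stored column-major.
b = np.arange(2 * 4 * 3, dtype=np.float32).reshape(2, 4, 3)  # row-major: K-stride = N

# Round-trip transpose -> contiguous copy -> transpose back: logical
# values are unchanged, but the element strides become (N*K, 1, K).
b_kmaj = np.ascontiguousarray(b.transpose(0, 2, 1)).transpose(0, 2, 1)

assert (b_kmaj == b).all()                 # same logical tensor
print([s // 4 for s in b.strides])         # [12, 3, 1]  (N-major)
print([s // 4 for s in b_kmaj.strides])    # [12, 1, 4]  (K-major)
```

A backend that assumes K-stride 1 but never checks it would read the N-major buffer in the wrong order, which matches the symptom in #3188 (cosine similarity near zero rather than near 1).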

🔍 Related Issues

Current CI runs

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Tests
    • Improved correctness checks for BF16 and FP8 matrix operations by normalizing inputs to match kernel layout expectations.
    • Kept original reference comparisons unchanged to ensure consistent validation.
    • Preserved behavior for skipping when kernels are unavailable and maintained existing similarity thresholds.

@coderabbitai
Contributor

coderabbitai Bot commented Apr 26, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 3bf42f40-9458-4589-b307-1bdb05004af6

📥 Commits

Reviewing files that changed from the base of the PR and between 0110434a478431168f4c61cf1c29629fa25d5fef and 135f933.

📒 Files selected for processing (1)
  • tests/trace/test_reference_correctness.py

📝 Walkthrough

Walkthrough

The tests normalize the memory layout of batched b operands before invoking FlashInfer kernels: b/b_fp8 are transformed into value-preserving K-major copies b_kmaj/b_fp8_kmaj via transpose→contiguous→transpose and passed to flashinfer.bmm_bf16 / flashinfer.bmm_fp8; reference traces continue to use the original inputs.

Changes

Cohort / File(s): Test Preprocessing — tests/trace/test_reference_correctness.py
Summary: Create contiguity-normalized b_kmaj and b_fp8_kmaj using transpose→contiguous→transpose and pass them to flashinfer.bmm_bf16 / flashinfer.bmm_fp8; leave reference computations operating on the original b / b_fp8. Cosine-similarity thresholding and skip/exception logic are unchanged.
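A minimal sketch of the test pattern described above, with NumPy matmul standing in for both the flashinfer kernel call and the reference trace (names and shapes are illustrative, not the actual test code):

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((2, 5, 4)).astype(np.float32)  # (B, M, K)
b = rng.standard_normal((2, 4, 3)).astype(np.float32)  # (B, K, N)

# Contiguity normalization from the test: same values, K-major strides.
b_kmaj = b.transpose(0, 2, 1).copy().transpose(0, 2, 1)

out = a @ b_kmaj   # stand-in for flashinfer.bmm_bf16(a, b_kmaj, ...)
ref = a @ b        # reference trace keeps the original operand
assert np.allclose(out, ref)  # identical values; only memory layout differs
```

The point of the change is exactly this split: the kernel receives the K-major copy it requires, while the reference path is untouched, so the comparison still validates the kernel against the original inputs.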

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

Suggested labels

op: gemm

Suggested reviewers

  • sricketts
  • aleozlx
  • yongwww
  • saltyminty
  • yzh119

Poem

🐰 I hop through tensors, light and quick,
I flip their axes, make them stick,
I nudge them tidy, snug, and neat,
So kernels meet them, row and sheet,
Then I munch errors — tiny, slick.

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
Check name Status Explanation
Title check ✅ Passed Title clearly identifies the specific fix (K-major requirement) and affected test (trace-bmm-fp8), directly matching the code changes.
Description check ✅ Passed Description includes linked issue (#3188), root cause explanation, changes made, and pre-commit/test checklist completion.
Linked Issues check ✅ Passed PR fixes failing test by ensuring B tensors are K-major for FP8/BF16 bmm operations, directly addressing issue #3188's requirements.
Out of Scope Changes check ✅ Passed All changes are confined to test tensor preparation for bmm_bf16 and bmm_fp8 tests, directly related to fixing the K-major requirement issue.
Docstring Coverage ✅ Passed Docstring coverage is 100.00% which is sufficient. The required threshold is 80.00%.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Comment @coderabbitai help to get the list of available commands and usage tips.

@xrq-phys
Contributor Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !605 has been created, and the CI pipeline #49554921 is currently running. I'll report back once the pipeline job completes.

Contributor

@coderabbitai coderabbitai Bot left a comment


🧹 Nitpick comments (1)
tests/trace/test_reference_correctness.py (1)

2170-2172: Correct K-major preprocessing for bmm_bf16.

The transpose(1,2).contiguous().transpose(1,2) idiom produces a (B, K, N) view with stride (N*K, 1, K) (K-stride = 1), which matches the column-major layout bmm_bf16 documents for B. Logical values are preserved, so passing original b to the reference remains correct.

Optional: a one-line comment would help future readers understand why the seemingly-noop pattern is necessary — i.e., sub-32-bit BMMs require K-major B, but only the cutlass backend enforces it.

📝 Optional clarifying comment
     a = torch.randn(B, M, K, dtype=torch.bfloat16, device="cuda")
     b = torch.randn(B, K, N, dtype=torch.bfloat16, device="cuda")
+    # bmm_bf16 requires B in K-major (column-major) layout; round-trip through
+    # contiguous() to get strides (N*K, 1, K) without changing logical values.
     b_kmaj = b.transpose(1, 2).contiguous().transpose(1, 2)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/trace/test_reference_correctness.py` around lines 2170 - 2172, The
pre-processing using b_kmaj = b.transpose(1, 2).contiguous().transpose(1, 2) is
a no-op for logical values but was used to get K-major strides required only by
the cutlass backend; update the test to pass the original b (not b_kmaj) to the
reference path and keep the cutlass call as-is (api = flashinfer.bmm_bf16(a,
b_kmaj, backend="cutlass")), and add a one-line comment near b_kmaj explaining
that the transpose/contiguous/transpose is only to enforce K-major memory layout
for cutlass and that logical values are unchanged so the reference uses the
original b.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: c60b9edf-6ad1-467f-958f-dcc573d8be88

📥 Commits

Reviewing files that changed from the base of the PR and between 5e1318c and 1c1a4bdf5c30db72fbbd17ffa73aa77cfcf1a6bf.

📒 Files selected for processing (1)
  • tests/trace/test_reference_correctness.py

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request modifies the bmm_bf16 and bmm_fp8 reference correctness tests to utilize K-major layout tensors when calling the FlashInfer API. Feedback suggests also using these K-major tensors in the reference implementation calls to maintain consistency and ensure that both the kernel and the reference are tested against the same memory representation.

Comment thread tests/trace/test_reference_correctness.py
Comment thread tests/trace/test_reference_correctness.py
@xrq-phys
Contributor Author

@saltyminty could you approve / merge this PR?

#2711 SageAttn (presumably other CI runs also) is blocked by this failure.

CC @YangXu1990uiuc for visibility.

Thanks!

@xrq-phys
Contributor Author

/bot help

@flashinfer-bot
Collaborator

FlashInfer CI Bot

Available Commands:

  • /bot run - Mirror this PR to GitLab and run CI pipeline
  • /bot status - Check current pipeline status
  • /bot stop - Cancel running pipeline
  • /bot help - Show this help message

Authorization:

Only whitelisted users can trigger CI. Contact a maintainer for access.

How It Works:

  1. Authorized user comments /bot run on a PR
  2. Bot mirrors PR to internal GitLab
  3. GitLab CI pipeline runs automatically
  4. Results are posted back to this PR

Note: Any whitelisted user can trigger CI for any PR, not just their own.

@saltyminty
Collaborator

CI looks good (failures are node allocation timeouts)

@saltyminty saltyminty self-assigned this Apr 27, 2026
@saltyminty saltyminty enabled auto-merge (squash) April 27, 2026 22:56
@xrq-phys
Contributor Author

@saltyminty can we skip CI here? Or do we have to wait until nodes are back?

@saltyminty saltyminty disabled auto-merge April 28, 2026 05:32
@saltyminty
Collaborator

We can skip internal CI since this change should be safe, but need the pre-merge checks to pass before the merge button appears

@xrq-phys xrq-phys force-pushed the fix/trace-bmm-fp8 branch from cebc7a3 to 0110434 Compare April 29, 2026 05:48
@saltyminty saltyminty enabled auto-merge (squash) April 29, 2026 06:14
@xrq-phys
Contributor Author

Hi @saltyminty

I couldn't find where the failed test actually errored:

PR Test / AOT Build Import (x64, cu130) (pull_request)

Looks like a timeout / node failure killed all the other jobs?

(Similar goes for #2711 )

@saltyminty saltyminty disabled auto-merge April 29, 2026 16:21
Signed-off-by: Ruqing Xu <7891482+xrq-phys@users.noreply.github.com>
@saltyminty saltyminty enabled auto-merge (squash) April 30, 2026 23:01
@saltyminty saltyminty disabled auto-merge May 1, 2026 02:22
@saltyminty saltyminty merged commit f6d49c4 into flashinfer-ai:main May 1, 2026
28 of 34 checks passed
@xrq-phys xrq-phys deleted the fix/trace-bmm-fp8 branch May 1, 2026 03:03


Development

Successfully merging this pull request may close these issues.

[Bug] test_bmm_fp8_reference_correctness fails with cos_sim=-0.0019 < 0.99

3 participants