build(jit-cache): split flashinfer-jit-cache wheels by SM family #3265
dierksen wants to merge 9 commits into
Conversation
The cu130 flashinfer-jit-cache wheel grew to 2.0 GB and started failing to upload as a GitHub Release asset (per-asset 2 GiB limit; see flashinfer-ai#3257). Each new SM target appended ~150-200 MB compressed to every wheel, and cu130 carries 8 (sm75/80/89/90a/100a/103a/110a/120f).

Split each (CUDA, CPU-arch) wheel into three by GPU SM family:

- sm9x - Ampere/Ada/Hopper (<= sm90a)
- sm10x - Datacenter Blackwell (sm100a/103a/110a)
- sm12x - Consumer Blackwell (sm120f, future sm121a)

Same package name everywhere; the family is encoded in the PEP 440 local-version, so wheels resolve as e.g. 'flashinfer-jit-cache==0.6.11+cu130.sm10x'. Existing 'pip install flashinfer-jit-cache' still works once the right pin is given.

Driven by a new 'flashinfer install-jit-cache-wheel' subcommand that detects FlashInfer version, CUDA version, and GPU compute capability (via torch.cuda.get_device_capability) and runs the matching pip install. Honors --cuda-version, --sm-family, --nightly, --dry-run. Modeled on the CLI scaffolding from flashinfer-ai#3142 with the family dimension added.

Build side: the 'FLASHINFER_JIT_CACHE_SM_FAMILY' env var, when set, filters 'FLASHINFER_CUDA_ARCH_LIST' to the family's archs and appends '.<family>' to the local-version suffix. Release / nightly workflows gain an 'sm_family' matrix dimension; the upload-to-release loop iterates over all three families. The wheel-index regex accepts the new local-version shape and remains compatible with the legacy '+cuXY' format.

Closes flashinfer-ai#3257

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
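A rough sketch of the family split described above. The function names here are illustrative, not the actual build_utils.py / build_backend.py API, and the later sm80 base-arch refinement is ignored:

```python
# Illustrative only: maps a compute capability to the three wheel families
# described above and filters a FLASHINFER_CUDA_ARCH_LIST-style string down
# to one family. The real helpers may differ in names and edge-case handling.
from typing import List


def sm_family_for_capability(major: int, minor: int) -> str:
    if major <= 9:
        return "sm9x"   # Ampere / Ada / Hopper (<= sm90a)
    if major in (10, 11):
        return "sm10x"  # datacenter Blackwell (sm100a / sm103a / sm110a)
    return "sm12x"      # consumer Blackwell (sm120f, ...)


def filter_arch_list(arch_list: str, family: str) -> str:
    """Keep only the FLASHINFER_CUDA_ARCH_LIST entries belonging to `family`."""
    kept: List[str] = []
    for arch in arch_list.split():  # entries look like '9.0a', '12.0f'
        major = int(arch.split(".")[0])
        if sm_family_for_capability(major, 0) == family:
            kept.append(arch)
    return " ".join(kept)


print(filter_arch_list("7.5 8.0 8.9 9.0a 10.0a 10.3a 11.0a 12.0f", "sm10x"))
# -> "10.0a 10.3a 11.0a"
```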
📝 Walkthrough

Adds SM-family-aware JIT-cache wheel distribution: utilities to map/filter CUDA arches by SM family, build backend and metadata changes, a new install CLI with autodetection, CI matrix/artifact updates, wheel-index parsing, docs updates, and corresponding tests.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
🚥 Pre-merge checks: 4 passed, 1 failed (warning)
Code Review
This pull request introduces a new CLI command, flashinfer install-jit-cache-wheel, to automate the installation of flashinfer-jit-cache wheels by autodetecting the CUDA version and GPU SM family. This update supports a new distribution model where wheels are split by SM family to comply with GitHub's asset size limits. Feedback from the review suggests improving the robustness of CUDA architecture parsing to handle various string formats and refactoring duplicated SM family logic into a common utility to enhance maintainability.
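For concreteness, "robustness of CUDA architecture parsing" might mean tolerating the separator and suffix variations that arch-list strings allow. A minimal sketch, not the parser actually used in this PR:

```python
# Hypothetical parser, assuming entries such as '9.0a', '12.0f', '8.0+PTX',
# separated by spaces, semicolons, or commas; not the implementation here.
import re
from typing import List, Tuple

_ARCH_RE = re.compile(r"^(\d+)\.(\d+)([a-z]?)(\+PTX)?$")


def parse_arch_entries(raw: str) -> List[Tuple[int, int, str]]:
    parsed = []
    for entry in filter(None, re.split(r"[;,\s]+", raw.strip())):
        m = _ARCH_RE.match(entry)
        if not m:
            raise ValueError(f"unrecognized CUDA arch entry: {entry!r}")
        parsed.append((int(m.group(1)), int(m.group(2)), m.group(3)))
    return parsed


print(parse_arch_entries("7.5 8.0+PTX; 9.0a,12.0f"))
# [(7, 5, ''), (8, 0, ''), (9, 0, 'a'), (12, 0, 'f')]
```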
Mirrors the per-family split in release/nightly so PR CI actually exercises the per-family build path. Was previously running one job per (cuda, arch) which still built every arch; now runs three jobs per (cuda, arch), one per SM family, each compiling only its family's archs.

- pr-test.yml: 'aot-build-import' and 'aot-build-import-rerun' gain 'sm_family: [sm9x, sm10x, sm12x]'. cu126 is excluded for sm10x and sm12x because that toolkit only supports archs <= sm90. The rerun matrix builder mirrors the same exclude. FLASHINFER_JIT_CACHE_SM_FAMILY is forwarded into the test container via ci/bash.sh's '-e' flag.
- task_test_jit_cache_package_build_import.sh: when FLASHINFER_JIT_CACHE_SM_FAMILY is set, filter FLASHINFER_CUDA_ARCH_LIST to that family's archs before running the wheel build and verify_all_modules_compiled.py. The build-side filter in build_backend.py mutates os.environ inside its own process only, so doing it once in the parent shell ensures both subprocesses see the same arch list.

Also fix black formatting flagged by pre-commit on PR flashinfer-ai#3265:

- build_backend.py: rewrite the SM_FAMILIES lambdas as named functions to avoid black's awkward multi-line break of '<' chained comparisons (see the sketch below).
- __main__.py: collapse a ClickException to a single line per black's preference.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
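The lambda-to-named-function change mentioned in the last bullet might look roughly like this; the exact predicates and boundaries are assumptions, only the SM_FAMILIES name comes from the commit message:

```python
# Sketch only: replacing per-family lambdas (which black wraps awkwardly when
# they contain chained comparisons) with small named predicates.
def _is_sm9x(major: int) -> bool:
    return major <= 9


def _is_sm10x(major: int) -> bool:
    return 10 <= major <= 11


def _is_sm12x(major: int) -> bool:
    return major >= 12


SM_FAMILIES = {"sm9x": _is_sm9x, "sm10x": _is_sm10x, "sm12x": _is_sm12x}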
Actionable comments posted: 3
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@flashinfer-jit-cache/build_backend.py`:
- Around line 38-94: Run the project formatter (e.g. ruff format or pre-commit
run --all-files) and commit the resulting changes so the SM_FAMILIES dict and
the multi-line print in _apply_sm_family_filter are formatted to satisfy ruff;
specifically reformat the SM_FAMILIES declaration and the print(...) call in
_apply_sm_family_filter (and any other affected lines) and push the reformatted
file so CI passes.
In `@flashinfer/__main__.py`:
- Around line 267-279: Re-run the project's formatter (ruff format) to apply the
canonical formatting for the Click exception lines in the CUDA-version parsing
block: ensure the click.ClickException(...) call around the InvalidVersion
exception handling and the earlier validation (the calls that raise
click.ClickException when normalized startswith "cu" and in the except block
that wraps InvalidVersion) are formatted according to ruff so the pre-commit
check passes; after formatting, stage and commit the changes.
- Around line 350-403: The current install_jit_cache_wheel_cmd builds an exact
pinned requirement from resolved_flashinfer_version which breaks when --nightly
points at nightly index but the installed __version__ is a stable release;
modify install_jit_cache_wheel_cmd to detect nightly and, if nightly is True and
resolved_flashinfer_version is a release (no "dev" or "+"), construct a range
requirement instead of an exact pin (e.g.
"flashinfer-jit-cache>={base},<{next_major_or_minor}") by parsing
resolved_flashinfer_version with packaging.version to compute the next version
bound, or alternatively call a new flag-aware helper (update
_build_jit_cache_requirement or add _build_jit_cache_requirement_for_nightly)
that returns the looser requirement when nightly is set; ensure the printed
requirement and pip args use this new requirement variable.
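A minimal sketch of the looser nightly requirement suggested above, assuming packaging is available; the helper name is illustrative, and the PR ultimately took a different route (rejecting non-dev versions under --nightly, per the testing notes below):

```python
from packaging.version import Version


def build_jit_cache_requirement(version: str, cuda: str, family: str, nightly: bool) -> str:
    v = Version(version)
    if nightly and not v.is_devrelease:
        # Stable flashinfer build against the nightly index: pin a range
        # instead of an exact version that will never exist there.
        return f"flashinfer-jit-cache>={v.base_version},<{v.major}.{v.minor + 1}"
    # Otherwise exact-pin the matching family wheel (assumes `version` has no
    # local segment of its own).
    return f"flashinfer-jit-cache=={version}+{cuda}.{family}"


print(build_jit_cache_requirement("0.6.11", "cu130", "sm10x", nightly=False))
# flashinfer-jit-cache==0.6.11+cu130.sm10x
print(build_jit_cache_requirement("0.6.11", "cu130", "sm10x", nightly=True))
# flashinfer-jit-cache>=0.6.11,<0.7
```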
📒 Files selected for processing (7)
- .github/workflows/nightly-release.yml
- .github/workflows/release.yml
- README.md
- docs/installation.rst
- flashinfer-jit-cache/build_backend.py
- flashinfer/__main__.py
- scripts/update_whl_index.py
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In @.github/workflows/pr-test.yml:
- Around line 244-247: The run step invokes ci/bash.sh with an unquoted
${DOCKER_IMAGE}, which triggers SC2086 (word-splitting); update the command to
quote the variable as "$DOCKER_IMAGE" in the invocation (e.g., change ci/bash.sh
${DOCKER_IMAGE} --no-gpu ... to ci/bash.sh "$DOCKER_IMAGE" --no-gpu ...), and
make the same change in the equivalent rerun "Run Test" step that calls the same
command line so both occurrences use "$DOCKER_IMAGE".
📒 Files selected for processing (4)
- .github/workflows/pr-test.yml
- flashinfer-jit-cache/build_backend.py
- flashinfer/__main__.py
- scripts/task_test_jit_cache_package_build_import.sh
The Release workflow has a 'pull_request: paths: .github/workflows/
release.yml' trigger that runs the build jobs in dry-run mode whenever
release.yml changes. Its checkout used:
ref: ${{ github.event_name == 'pull_request' && github.head_ref || inputs.tag }}
For a fork PR, github.head_ref resolves to a branch that doesn't exist
on flashinfer-ai/flashinfer (because actions/checkout defaults
'repository:' to the workflow's repo). 'git fetch' fails three times,
and the setup job dies before any actual build work runs.
The bug has been latent since flashinfer-ai#1910 (2025-10-10), where the trigger
and the buggy checkout were introduced together. It only fires on
fork-PRs that touch release.yml; PRs from branches on the main repo
work fine because the default 'repository:' already matches.
Fix all four affected checkouts to set 'repository:' explicitly to
the PR head's repo and pin to head.sha (which is also stable across
re-pushes during the run). On workflow_dispatch the existing
'inputs.tag' path is unchanged.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
/bot run
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@flashinfer/jit/env.py`:
- Around line 126-127: The compatibility check currently uses
_public_package_version(flashinfer_version) !=
_public_package_version(flashinfer_jit_cache_version) but strips the
local-version suffix so different SM-family suffixes (e.g. .sm9x vs .sm12x) are
ignored; update the logic to, when CUDA is available, extract the sm* suffix
from flashinfer_jit_cache_version (e.g. via a small regex on the local-version
segment) and compare it to the detected device family (use your CUDA detection
helper / device-family variable); if the sm suffix is present and does not match
the detected device family, raise the same incompatibility error (or fail fast)
instead of proceeding, while falling back to the existing
_public_package_version check for non-CUDA cases.
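A minimal sketch of that suffix check, assuming the installed jit-cache version string looks like "0.6.11+cu130.sm12x"; this is not the actual flashinfer/jit/env.py code, and the detected-family helper is assumed to exist elsewhere:

```python
import re
from typing import Optional

from packaging.version import Version


def jit_cache_sm_family(jit_cache_version: str) -> Optional[str]:
    """Extract the sm-family suffix ('sm9x', 'sm10x', 'sm12x') if present."""
    local = Version(jit_cache_version).local or ""
    m = re.search(r"\bsm\d+x\b", local)
    return m.group(0) if m else None


def check_family(jit_cache_version: str, detected_family: str) -> None:
    family = jit_cache_sm_family(jit_cache_version)
    # Legacy '+cuXY' wheels carry no suffix and are skipped here.
    if family is not None and family != detected_family:
        raise RuntimeError(
            f"flashinfer-jit-cache was built for {family} but this GPU needs "
            f"{detected_family}; reinstall the matching wheel."
        )


check_family("0.6.11+cu130.sm12x", "sm12x")  # passes
```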
📒 Files selected for processing (7)
- .github/workflows/nightly-release.yml
- .github/workflows/pr-test.yml
- .github/workflows/release.yml
- flashinfer/__main__.py
- flashinfer/jit/env.py
- scripts/task_test_jit_cache_package_build_import.sh
- tests/cli/test_cli_cmds.py
🚧 Files skipped from review as they are similar to previous changes (4)
- scripts/task_test_jit_cache_package_build_import.sh
- .github/workflows/nightly-release.yml
- .github/workflows/pr-test.yml
- .github/workflows/release.yml
re: the 3 GPU SM families in the PR description, I think @aleozlx mentioned earlier in the thread that each device typically requires 8.0 plus their native arch -- should we add sm80a compilation to the sm10x and sm12x subwheels as well?
# Conflicts:
#	scripts/task_test_jit_cache_package_build_import.sh
The cu130 flashinfer-jit-cache wheel grew to 2.0 GB and started failing to upload as a GitHub Release asset (per-asset 2 GiB limit; see #3257). Each new SM target appended ~150-200 MB compressed to every wheel, and cu130 carries 8 (sm75/80/89/90a/100a/103a/110a/120f).
Split each (CUDA, CPU-arch) wheel into three by GPU SM family:

- sm9x - Ampere/Ada/Hopper (<= sm90a)
- sm10x - Datacenter Blackwell (sm100a/103a/110a)
- sm12x - Consumer Blackwell (sm120f, future sm121a)

Same package name everywhere; the family is encoded in the PEP 440 local-version, so wheels resolve as e.g. 'flashinfer-jit-cache==0.6.11+cu130.sm10x'. Existing 'pip install flashinfer-jit-cache' still works once the right pin is given.
Driven by a new 'flashinfer install-jit-cache-wheel' subcommand that detects FlashInfer version, CUDA version, and GPU compute capability (via torch.cuda.get_device_capability) and runs the matching pip install. Honors --cuda-version, --sm-family, --nightly, --dry-run. Modeled on the CLI scaffolding from #3142 with the family dimension added.
Build side: the 'FLASHINFER_JIT_CACHE_SM_FAMILY' env var, when set, filters 'FLASHINFER_CUDA_ARCH_LIST' to the family's archs and appends '.<family>' to the local-version suffix. Release / nightly workflows gain an 'sm_family' matrix dimension; the upload-to-release loop iterates over all three families. The wheel-index regex accepts the new local-version shape and remains compatible with the legacy '+cuXY' format.
Closes #3257
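For reference, the pinned install the description implies would look roughly like this. The index URL layout and flag choice are assumptions based on the testing notes below; the real CLI may also fall back to `uv pip install --python ...` when pip is unavailable:

```python
# Illustrative only: roughly the pip invocation the new subcommand builds for
# a cu130 / sm10x host.
import subprocess
import sys

requirement = "flashinfer-jit-cache==0.6.11+cu130.sm10x"
cmd = [
    sys.executable, "-m", "pip", "install",
    requirement,
    "--extra-index-url", "https://flashinfer.ai/whl/cu130",
]
print("would run:", " ".join(cmd))
# subprocess.check_call(cmd)  # uncomment to actually install
```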
📌 Description
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks before committing or used `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- Tests have been added or updated as needed.
- All tests are passing (`unittest`, etc.).

Reviewer Notes
Testing Done
Local DGX Spark validation:
- The host is `aarch64` with an NVIDIA GB10, compute capability `12.1`; PyTorch sees CUDA `13.0` and reports the device as `(12, 1)`.
- Ran `flashinfer install-jit-cache-wheel --dry-run` from the source checkout. It now resolves from `version.txt` when package metadata is `0.0.0+unknown`, detects CUDA `13.0`, selects `sm12x`, and prints `flashinfer-jit-cache==0.6.11+cu130.sm12x`.
- This environment has no `python -m pip`, so the CLI now falls back to `uv pip install --python ...`.
- Pointed a real install at `https://flashinfer.ai/whl/cu130`; it reached the resolver cleanly and failed only because `0.6.11+cu130.sm12x` is not yet published.
- Downloaded the `jit-cache-cu129-aarch64-sm12x` artifact from Release run `25528598737`, served it through a local simple index, installed it with explicit `--cuda-version cu129 --sm-family sm12x --index-url ...`, imported `flashinfer_jit_cache`, verified `FLASHINFER_AOT_DIR` points at the installed package cache, then uninstalled it because the local machine is CUDA 13.0.
- The `jit-cache-cu130-aarch64-sm12x` artifact is absent because that job failed while downloading build dependencies (`IncompleteRead` for `nvidia_cublas`), not because of the CLI install path.

Merge/conflict validation:
- Merged `upstream/main` and resolved conflicts in `pr-test.yml`, `release.yml`, and `nightly-release.yml`, preserving both `FLASHINFER_JIT_CACHE_SM_FAMILY` forwarding and upstream sccache/NVCC env forwarding.
- Parsed the touched workflow YAML files with `yaml.safe_load`.
- Ran `git diff --cached --check`.
- Ran `python -m py_compile flashinfer/__main__.py flashinfer/jit/env.py tests/cli/test_cli_cmds.py`.
- Ran `python -m pytest tests/cli/test_cli_cmds.py -q` (18 passed).

Review-comment follow-up:
- Handled `90`, `100`, and `120` in the SM-family filter, and moved the duplicated SM-family helpers into `build_utils.py` for reuse by both the CLI and jit-cache build backend.
- Addressed the `--nightly` concern by rejecting nightly installs when the resolved FlashInfer version is not a dev release; dev versions still exact-pin the matching SM-family wheel, e.g. `flashinfer-jit-cache==0.6.11.dev20260508+cu130.sm12x`.
- Ran `python -m py_compile build_utils.py flashinfer/__main__.py flashinfer-jit-cache/build_backend.py tests/cli/test_cli_cmds.py`.
- Ran `python -m pytest tests/cli/test_cli_cmds.py -q` (21 passed, with the expected PyTorch GB10 capability warning from this host's torch build).
- Ran `git diff --check`.
- Re-ran `flashinfer install-jit-cache-wheel --cuda-version cu130 --sm-family sm12x --dry-run`, which resolves `flashinfer-jit-cache==0.6.11+cu130.sm12x` and the `uv pip install --python ...` command.
- With `--nightly`, `0.6.11` now fails early with the new explanatory error, while an explicit `0.6.11.dev20260508` resolves `flashinfer-jit-cache==0.6.11.dev20260508+cu130.sm12x` against `https://flashinfer.ai/whl/nightly/cu130` with `--pre`.

Human feedback follow-up:
- Reworked `flashinfer install-jit-cache-wheel` autodetection to inspect every visible CUDA device instead of only device 0. It selects a wheel only when the visible GPUs are covered by one jit-cache SM-family wheel, and otherwise fails with guidance to pass `--sm-family` or build from source with an explicit `FLASHINFER_CUDA_ARCH_LIST` (see the sketch after this list).
- Blackwell family wheels now carry the `sm80` base arch plus native Blackwell archs. The build-side family filter now keeps/adds `8.0` for `sm10x` and `sm12x` only when a native arch for that family is present.
- Standardized the `sm12x` default arch lists on `12.0f`; the family-specific `sm120f` target covers DGX Spark / GB10 (sm121) without adding an exact `12.1a` target by default.
- The jit-cache compatibility check now validates `flashinfer-jit-cache` local-version SM suffixes. On CUDA hosts, a wrong-family installed wheel now fails fast; on this DGX Spark, `0.6.11+cu130.sm12x` validates and `0.6.11+cu130.sm9x` fails with an expected-family error.
- Ran `python -m py_compile build_utils.py flashinfer/__main__.py flashinfer/jit/env.py flashinfer-jit-cache/build_backend.py tests/cli/test_cli_cmds.py`.
- Ran `python -m pytest tests/cli/test_cli_cmds.py -q` (27 passed, with the expected PyTorch GB10 capability warning from this host's torch build).
- Ran `uvx ruff check ...` and `uvx ruff format --check ...` over the touched Python files.
- Ran `git diff --check`, `bash -n scripts/task_test_jit_cache_package_build_import.sh`, and parsed the touched workflow YAML files with `yaml.safe_load`.
- Verified the family filter yields `sm10x: 8.0 10.0a 10.3a 11.0a` and `sm12x: 8.0 12.0f` for the CUDA 13.0 release arch list.
- The DGX Spark CLI dry-run still detects `sm12x` and resolves `flashinfer-jit-cache==0.6.11+cu130.sm12x` using the `uv pip install --python ...` fallback.
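A minimal sketch of that multi-GPU autodetection, assuming torch is importable and CUDA is available; the real CLI handles missing torch/CUDA and error wording differently:

```python
# Hypothetical sketch of the multi-device check described above: every visible
# GPU must fall into the same SM family, otherwise the caller should pass
# --sm-family explicitly.
import torch


def detect_single_sm_family() -> str:
    families = set()
    for dev in range(torch.cuda.device_count()):
        major, _minor = torch.cuda.get_device_capability(dev)
        if major <= 9:
            families.add("sm9x")
        elif major in (10, 11):
            families.add("sm10x")
        else:
            families.add("sm12x")
    if len(families) != 1:
        raise RuntimeError(
            f"Visible GPUs span SM families {sorted(families)}; pass --sm-family "
            "explicitly or build from source with an explicit FLASHINFER_CUDA_ARCH_LIST."
        )
    return families.pop()
```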
SM121 target cleanup:

- Removed the `12.1a` additions from release/nightly/default jit-cache arch lists and docs; `sm12x` now defaults to `8.0 12.0f`.
- Kept `12.1a` support in the parser/filter if a user supplies it manually, but release artifacts no longer build it by default.
- Re-ran `python -m pytest tests/cli/test_cli_cmds.py -q` (27 passed), `uvx ruff check tests/cli/test_cli_cmds.py`, `uvx ruff format --check tests/cli/test_cli_cmds.py`, `bash -n scripts/task_test_jit_cache_package_build_import.sh`, `git diff --check`, workflow YAML parsing, and the DGX Spark CLI dry-run.
SM110 architecture split:

- Moved `11.0a` / `sm110` jit-cache build coverage to the CUDA 13.0 `aarch64` release, nightly, and PR AOT build/import arch lists. CUDA 13.0 `x86_64` lists now omit `11.0a`.
- Docs drop `11.0a` from the generic x86-oriented examples and call out adding it for Jetson AGX Thor / T5000 aarch64 targets.
- Re-ran `bash -n scripts/task_test_jit_cache_package_build_import.sh`, `git diff --check`, `uvx ruff check tests/cli/test_cli_cmds.py`, and `python -m pytest tests/cli/test_cli_cmds.py -q` (27 passed).
- Resulting CUDA 13.0 arch lists: `x86_64` -> `7.5 8.0 8.9 9.0a 10.0a 10.3a 12.0f`; `aarch64` -> `7.5 8.0 8.9 9.0a 10.0a 10.3a 11.0a 12.0f`.

Summary by CodeRabbit
New Features
Documentation
Tests / CI
Chores