50 changes: 29 additions & 21 deletions .github/workflows/nightly-release.yml
@@ -126,6 +126,7 @@ jobs:
matrix:
cuda: ["12.8", "12.9", "13.0"]
arch: ['x86_64', 'aarch64']
sm_family: ['sm9x', 'sm10x', 'sm12x']

runs-on: [self-hosted, linux, "${{ matrix.arch == 'aarch64' && 'arm64' || 'x64' }}", cpu, on-demand]

@@ -144,6 +145,7 @@ jobs:
echo "RAM: $(free -h | awk '/^Mem:/ {print $7 " available out of " $2}')"
echo "Disk: $(df -h / | awk 'NR==2 {print $4 " available out of " $2}')"
echo "Architecture: $(uname -m)"
echo "SM family: ${{ matrix.sm_family }}"

- name: Checkout code
uses: actions/checkout@v4
@@ -176,6 +178,7 @@ jobs:
-e FLASHINFER_DEV_RELEASE_SUFFIX="${FLASHINFER_DEV_RELEASE_SUFFIX}" \
-e ARCH="${{ matrix.arch }}" \
-e FLASHINFER_CUDA_ARCH_LIST="${FLASHINFER_CUDA_ARCH_LIST}" \
-e FLASHINFER_JIT_CACHE_SM_FAMILY="${{ matrix.sm_family }}" \
-w /workspace \
${{ env.DOCKER_IMAGE }} \
bash /workspace/scripts/build_flashinfer_jit_cache_whl.sh
@@ -188,7 +191,7 @@ jobs:
id: artifact-name
run: |
CUDA_NO_DOT=$(echo "${{ matrix.cuda }}" | tr -d '.')
echo "name=jit-cache-cu${CUDA_NO_DOT}-${{ matrix.arch }}" >> $GITHUB_OUTPUT
echo "name=jit-cache-cu${CUDA_NO_DOT}-${{ matrix.arch }}-${{ matrix.sm_family }}" >> $GITHUB_OUTPUT

- name: Upload flashinfer-jit-cache artifact
uses: actions/upload-artifact@v4
@@ -252,29 +255,32 @@ jobs:
env:
GH_TOKEN: ${{ github.token }}
run: |
# Upload jit-cache wheels one at a time to avoid OOM
# Each wheel can be several GB, so we download, upload, delete, repeat
# Upload jit-cache wheels one at a time to avoid OOM.
# Wheels are split per (CUDA, CPU-arch, SM family) to stay under
# GitHub Releases' 2 GiB per-asset limit.
mkdir -p dist-jit-cache

for cuda in 128 129 130; do
for arch in x86_64 aarch64; do
ARTIFACT_NAME="jit-cache-cu${cuda}-${arch}"
echo "Processing ${ARTIFACT_NAME}..."

# Download this specific artifact
gh run download ${{ github.run_id }} -n "${ARTIFACT_NAME}" -D dist-jit-cache/ || {
echo "Warning: Failed to download ${ARTIFACT_NAME}, skipping..."
continue
}

# Upload to release
if [ -n "$(ls -A dist-jit-cache/)" ]; then
gh release upload "${{ needs.setup.outputs.release_tag }}" dist-jit-cache/* --clobber
echo "✅ Uploaded ${ARTIFACT_NAME}"
fi

# Clean up to save disk space before next iteration
rm -rf dist-jit-cache/*
for sm_family in sm9x sm10x sm12x; do
ARTIFACT_NAME="jit-cache-cu${cuda}-${arch}-${sm_family}"
echo "Processing ${ARTIFACT_NAME}..."

# Download this specific artifact
gh run download ${{ github.run_id }} -n "${ARTIFACT_NAME}" -D dist-jit-cache/ || {
echo "Warning: Failed to download ${ARTIFACT_NAME}, skipping..."
continue
}

# Upload to release
if [ -n "$(ls -A dist-jit-cache/)" ]; then
gh release upload "${{ needs.setup.outputs.release_tag }}" dist-jit-cache/* --clobber
echo "✅ Uploaded ${ARTIFACT_NAME}"
fi

# Clean up to save disk space before next iteration
rm -rf dist-jit-cache/*
done
done
done

@@ -321,9 +327,11 @@ jobs:
path: dist-cubin/

- name: Download flashinfer-jit-cache artifact
# The test runner is an sm86 (Ampere) host, which falls under the
# sm9x family wheel. Update the family if test runners change.
uses: actions/download-artifact@v4
with:
name: jit-cache-cu${{ matrix.cuda == '12.9' && '129' || '130' }}-x86_64
name: jit-cache-cu${{ matrix.cuda == '12.9' && '129' || '130' }}-x86_64-sm9x
path: dist-jit-cache/

- name: Get Docker image tag
47 changes: 27 additions & 20 deletions .github/workflows/release.yml
@@ -162,6 +162,7 @@ jobs:
matrix:
cuda: ["12.8", "12.9", "13.0"]
arch: ['x86_64', 'aarch64']
sm_family: ['sm9x', 'sm10x', 'sm12x']

runs-on: [self-hosted, linux, "${{ matrix.arch == 'aarch64' && 'arm64' || 'x64' }}", cpu, on-demand]

@@ -172,6 +173,7 @@
echo "RAM: $(free -h | awk '/^Mem:/ {print $7 " available out of " $2}')"
echo "Disk: $(df -h / | awk 'NR==2 {print $4 " available out of " $2}')"
echo "Architecture: $(uname -m)"
echo "SM family: ${{ matrix.sm_family }}"

- name: Checkout code
uses: actions/checkout@v4
@@ -203,6 +205,7 @@ jobs:
-e FLASHINFER_LOCAL_VERSION="$FLASHINFER_LOCAL_VERSION" \
-e ARCH="${{ matrix.arch }}" \
-e FLASHINFER_CUDA_ARCH_LIST="${FLASHINFER_CUDA_ARCH_LIST}" \
-e FLASHINFER_JIT_CACHE_SM_FAMILY="${{ matrix.sm_family }}" \
-w /workspace \
${{ env.DOCKER_IMAGE }} \
bash /workspace/scripts/build_flashinfer_jit_cache_whl.sh
@@ -215,7 +218,7 @@ jobs:
id: artifact-name
run: |
CUDA_NO_DOT=$(echo "${{ matrix.cuda }}" | tr -d '.')
echo "name=jit-cache-cu${CUDA_NO_DOT}-${{ matrix.arch }}" >> $GITHUB_OUTPUT
echo "name=jit-cache-cu${CUDA_NO_DOT}-${{ matrix.arch }}-${{ matrix.sm_family }}" >> $GITHUB_OUTPUT

- name: Upload flashinfer-jit-cache artifact
uses: actions/upload-artifact@v4
@@ -314,29 +317,33 @@ jobs:
env:
GH_TOKEN: ${{ github.token }}
run: |
# Upload jit-cache wheels one at a time to avoid OOM
# Each wheel can be several GB, so we download, upload, delete, repeat
# Upload jit-cache wheels one at a time to avoid OOM.
# Each wheel can be ~1 GB, so we download, upload, delete, repeat.
# Wheels are split per (CUDA, CPU-arch, SM family) to stay under
# GitHub Releases' 2 GiB per-asset limit.
mkdir -p dist-jit-cache

for cuda in 128 129 130; do
for arch in x86_64 aarch64; do
ARTIFACT_NAME="jit-cache-cu${cuda}-${arch}"
echo "Processing ${ARTIFACT_NAME}..."

# Download this specific artifact
gh run download ${{ github.run_id }} -n "${ARTIFACT_NAME}" -D dist-jit-cache/ || {
echo "Warning: Failed to download ${ARTIFACT_NAME}, skipping..."
continue
}

# Upload to release
if [ -n "$(ls -A dist-jit-cache/)" ]; then
gh release upload "${{ needs.setup.outputs.release_tag }}" dist-jit-cache/* --clobber
echo "✅ Uploaded ${ARTIFACT_NAME}"
fi

# Clean up to save disk space before next iteration
rm -rf dist-jit-cache/*
for sm_family in sm9x sm10x sm12x; do
ARTIFACT_NAME="jit-cache-cu${cuda}-${arch}-${sm_family}"
echo "Processing ${ARTIFACT_NAME}..."

# Download this specific artifact
gh run download ${{ github.run_id }} -n "${ARTIFACT_NAME}" -D dist-jit-cache/ || {
echo "Warning: Failed to download ${ARTIFACT_NAME}, skipping..."
continue
}

# Upload to release
if [ -n "$(ls -A dist-jit-cache/)" ]; then
gh release upload "${{ needs.setup.outputs.release_tag }}" dist-jit-cache/* --clobber
echo "✅ Uploaded ${ARTIFACT_NAME}"
fi

# Clean up to save disk space before next iteration
rm -rf dist-jit-cache/*
done
done
done

17 changes: 13 additions & 4 deletions README.md
@@ -102,10 +102,13 @@ pip install flashinfer-python

```bash
pip install flashinfer-python flashinfer-cubin
# JIT cache (replace cu129 with your CUDA version)
pip install flashinfer-jit-cache --index-url https://flashinfer.ai/whl/cu129
# JIT cache: autodetect CUDA + GPU SM family and run the matching pip install.
# Use --dry-run to preview, --sm-family / --cuda-version to override.
flashinfer install-jit-cache-wheel
```

`flashinfer-jit-cache` is published as separate per-(CUDA, SM family) wheels: `sm9x` (Turing/Ampere/Ada/Hopper, ≤sm90), `sm10x` (Datacenter Blackwell, sm100/103/110), and `sm12x` (Consumer Blackwell, sm120/121). A single multi-arch wheel would exceed GitHub Releases' 2 GiB per-asset limit, so the wheels are split and the CLI resolves the right one for you.

**For Blackwell (SM100+) CuTe DSL kernels**, install with the CUDA 13 extra to enable Blackwell-optimized kernels:

```bash
@@ -168,6 +171,12 @@ export FLASHINFER_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 10.3a 11.0a 12.0f"
cd flashinfer-jit-cache
python -m build --no-isolation --wheel
python -m pip install dist/*.whl

# Or build a per-SM-family wheel matching a release artifact (smaller; one GPU family).
# The local-version on the resulting wheel will include the family (e.g. +cu130.sm10x).
export FLASHINFER_CUDA_ARCH_LIST="10.0a 10.3a 11.0a"
export FLASHINFER_JIT_CACHE_SM_FAMILY="sm10x"
python -m build --no-isolation --wheel
```

For more details, see the [Install from Source documentation](https://docs.flashinfer.ai/installation.html#install-from-source).
@@ -178,8 +187,8 @@ For more details, see the [Install from Source documentation](https://docs.flash
pip install -U --pre flashinfer-python --index-url https://flashinfer.ai/whl/nightly/ --no-deps
pip install flashinfer-python # Install dependencies from PyPI
pip install -U --pre flashinfer-cubin --index-url https://flashinfer.ai/whl/nightly/
# JIT cache (replace cu129 with your CUDA version)
pip install -U --pre flashinfer-jit-cache --index-url https://flashinfer.ai/whl/nightly/cu129
# JIT cache: autodetect CUDA + GPU and pull from the nightly index
flashinfer install-jit-cache-wheel --nightly
```

### CLI Tools
34 changes: 30 additions & 4 deletions docs/installation.rst
@@ -43,11 +43,26 @@ FlashInfer provides three packages:
.. code-block:: bash

pip install flashinfer-python flashinfer-cubin
# JIT cache package (replace cu129 with your CUDA version: cu128, cu129, or cu130)
pip install flashinfer-jit-cache --index-url https://flashinfer.ai/whl/cu129
# JIT cache package: autodetects CUDA + GPU SM family and runs the right pip install.
flashinfer install-jit-cache-wheel

This eliminates compilation and downloading overhead at runtime.

``flashinfer-jit-cache`` is published as separate per-(CUDA, SM family) wheels because a
single multi-arch wheel exceeds GitHub Releases' 2 GiB asset limit. The CLI resolves
the right one for you. To pick manually, override the autodetection:

.. code-block:: bash

# Datacenter Blackwell (sm100/103/110), CUDA 13.0
flashinfer install-jit-cache-wheel --cuda-version 13.0 --sm-family sm10x

# Show the pip command without executing
flashinfer install-jit-cache-wheel --dry-run

The SM families are: ``sm9x`` (Turing/Ampere/Ada/Hopper, ≤sm90), ``sm10x`` (Datacenter Blackwell,
sm100/103/110), and ``sm12x`` (Consumer Blackwell, sm120/121).
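
A minimal sketch of this family mapping (an assumption mirroring the ``SM_FAMILIES``
predicates in ``flashinfer-jit-cache/build_backend.py``; the CLI's actual detection is
``_detect_sm_family`` in ``flashinfer/__main__.py`` and may differ in detail):

.. code-block:: python

   def sm_family(major: int, minor: int) -> str:
       """Map a CUDA compute capability, e.g. (9, 0), to a jit-cache wheel family."""
       if major < 10:
           return "sm9x"   # Turing/Ampere/Ada/Hopper (<= sm90)
       if major < 12:
           return "sm10x"  # Datacenter Blackwell (sm100/103/110)
       return "sm12x"      # Consumer Blackwell (sm120/121)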


.. _install-from-source:

@@ -112,6 +127,17 @@ You can follow the steps below to install FlashInfer from source code:
python -m build --no-isolation --wheel
python -m pip install dist/*.whl

To build a per-SM-family wheel matching a release artifact (smaller; targets one
GPU family), set ``FLASHINFER_JIT_CACHE_SM_FAMILY``. The local-version suffix on
the resulting wheel will include the family (e.g. ``+cu130.sm10x``).

.. code-block:: bash

export FLASHINFER_CUDA_ARCH_LIST="10.0a 10.3a 11.0a"
export FLASHINFER_JIT_CACHE_SM_FAMILY="sm10x"
cd flashinfer-jit-cache
python -m build --no-isolation --wheel


Install Nightly Build
^^^^^^^^^^^^^^^^^^^^^^
@@ -124,8 +150,8 @@ Nightly builds are available for testing the latest features:
pip install -U --pre flashinfer-python --index-url https://flashinfer.ai/whl/nightly/ --no-deps # Install the nightly package from custom index, without installing dependencies
pip install flashinfer-python # Install flashinfer-python's dependencies from PyPI
pip install -U --pre flashinfer-cubin --index-url https://flashinfer.ai/whl/nightly/
# JIT cache package (replace cu129 with your CUDA version: cu128, cu129, or cu130)
pip install -U --pre flashinfer-jit-cache --index-url https://flashinfer.ai/whl/nightly/cu129
# JIT cache package: autodetect CUDA + GPU and pull from the nightly index
flashinfer install-jit-cache-wheel --nightly

Verify Installation
^^^^^^^^^^^^^^^^^^^
78 changes: 78 additions & 0 deletions flashinfer-jit-cache/build_backend.py
@@ -30,6 +30,71 @@
os.environ["FLASHINFER_DISABLE_VERSION_CHECK"] = "1"


# SM family → arch list filter. Each entry is the predicate `(major, minor) -> bool`
# applied to entries in FLASHINFER_CUDA_ARCH_LIST. The local-version suffix encodes the
# family so users (and pip) can resolve the right wheel: e.g. "0.6.11+cu130.sm10x".
#
# Keep this in sync with `_detect_sm_family` in flashinfer/__main__.py.
SM_FAMILIES = {
"sm9x": lambda major, minor: major < 10, # 7.5 / 8.0 / 8.9 / 9.0a (Ampere/Ada/Hopper)
"sm10x": lambda major, minor: 10 <= major < 12, # 10.0a / 10.3a / 11.0a (Datacenter Blackwell)
"sm12x": lambda major, minor: major >= 12, # 12.0f / 12.1a (Consumer Blackwell)
}

def _filter_arch_list_for_family(arch_list: str, family: str) -> str:
"""Filter a space-separated FLASHINFER_CUDA_ARCH_LIST to only entries belonging to `family`."""
predicate = SM_FAMILIES[family]
kept = []
for entry in arch_list.split():
major_str, minor_str = entry.split(".", 1)
major = int(major_str)
# `minor_str` may carry a suffix like 'a' or 'f' — keep the whole string for output,
# but parse leading digits for the comparison.
leading_digits = "".join(c for c in minor_str if c.isdigit())
minor = int(leading_digits) if leading_digits else 0
if predicate(major, minor):
kept.append(entry)
return " ".join(kept)


def _resolve_sm_family() -> str:
"""Return the SM family this build targets, or '' for a legacy multi-family build."""
family = os.environ.get("FLASHINFER_JIT_CACHE_SM_FAMILY", "").strip().lower()
if not family:
return ""
if family not in SM_FAMILIES:
raise RuntimeError(
f"Invalid FLASHINFER_JIT_CACHE_SM_FAMILY={family!r}. "
f"Expected one of: {sorted(SM_FAMILIES)}"
)
return family


def _apply_sm_family_filter() -> str:
"""If a family is selected, narrow FLASHINFER_CUDA_ARCH_LIST in-place and return the family suffix."""
family = _resolve_sm_family()
if not family:
return ""

arch_list = os.environ.get("FLASHINFER_CUDA_ARCH_LIST")
if not arch_list:
# The downstream build will fail with a clear error in compile_and_package_modules;
# we let it raise there to keep error messages consistent.
return family

filtered = _filter_arch_list_for_family(arch_list, family)
if not filtered:
raise RuntimeError(
f"FLASHINFER_JIT_CACHE_SM_FAMILY={family} but FLASHINFER_CUDA_ARCH_LIST="
f"{arch_list!r} contains no archs in that family. "
f"Set FLASHINFER_CUDA_ARCH_LIST to include archs matching {family}."
)
print(f"SM family {family}: filtering FLASHINFER_CUDA_ARCH_LIST {arch_list!r} -> {filtered!r}")
os.environ["FLASHINFER_CUDA_ARCH_LIST"] = filtered
return family


def _create_build_metadata():
"""Create build metadata file with version information."""
version_file = Path(__file__).parent.parent / "version.txt"
@@ -49,6 +114,13 @@ def _create_build_metadata():

# Append local version suffix if available
local_version = os.environ.get("FLASHINFER_LOCAL_VERSION")

# When this build targets a single SM family, append it to the local-version
# so users can pin e.g. "flashinfer-jit-cache==0.6.11+cu130.sm10x".
family = _resolve_sm_family()
if family:
local_version = f"{local_version}.{family}" if local_version else family

if local_version:
# Use + to create a local version identifier that will appear in wheel name
version = f"{version}+{local_version}"
@@ -157,6 +229,12 @@ def _build_aot_modules():

def _prepare_build():
"""Shared preparation logic for both wheel and editable builds."""
_apply_sm_family_filter()
# Re-derive the build metadata so the family suffix is reflected in
# `_build_meta.py` even when the import-time call ran before the env var
# was set (e.g. when callers configure FLASHINFER_JIT_CACHE_SM_FAMILY in
# the same process before invoking build_wheel).
_create_build_metadata()
_build_aot_modules()

