Skip to content

Commit 705d15a

Browse files
committed
Thor & Spark Support
1 parent da01b1b commit 705d15a

5 files changed

Lines changed: 7 additions & 4 deletions

File tree

.github/workflows/nightly-release.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ jobs:
145145
- name: Build wheel in container
146146
env:
147147
DOCKER_IMAGE: ${{ matrix.arch == 'aarch64' && format('pytorch/manylinuxaarch64-builder:cuda{0}', matrix.cuda) || format('pytorch/manylinux2_28-builder:cuda{0}', matrix.cuda) }}
148-
FLASHINFER_CUDA_ARCH_LIST: ${{ matrix.cuda == '12.8' && '7.5 8.0 8.9 9.0a 10.0a 12.0a' || '7.5 8.0 8.9 9.0a 10.0a 10.3a 12.0a' }}
148+
FLASHINFER_CUDA_ARCH_LIST: ${{ matrix.cuda == '12.8' && '7.5 8.0 8.9 9.0a 10.0a 12.0a' || '7.5 8.0 8.9 9.0a 10.0a 10.3a 11.0a 12.0a 12.1a' }}
149149
FLASHINFER_DEV_RELEASE_SUFFIX: ${{ needs.setup.outputs.dev_suffix }}
150150
run: |
151151
# Extract CUDA major and minor versions

.github/workflows/release.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ jobs:
182182
- name: Build wheel in container
183183
env:
184184
DOCKER_IMAGE: ${{ matrix.arch == 'aarch64' && format('pytorch/manylinuxaarch64-builder:cuda{0}', matrix.cuda) || format('pytorch/manylinux2_28-builder:cuda{0}', matrix.cuda) }}
185-
FLASHINFER_CUDA_ARCH_LIST: ${{ matrix.cuda == '12.8' && '7.5 8.0 8.9 9.0a 10.0a 12.0a' || '7.5 8.0 8.9 9.0a 10.0a 10.3a 12.0a' }}
185+
FLASHINFER_CUDA_ARCH_LIST: ${{ matrix.cuda == '12.8' && '7.5 8.0 8.9 9.0a 10.0a 12.0a' || '7.5 8.0 8.9 9.0a 10.0a 10.3a 11.0a 12.0a 12.1a' }}
186186
run: |
187187
# Extract CUDA major and minor versions
188188
CUDA_MAJOR=$(echo "${{ matrix.cuda }}" | cut -d'.' -f1)

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ python -m pip install dist/*.whl
9090

9191
`flashinfer-jit-cache` (customize `FLASHINFER_CUDA_ARCH_LIST` for your target GPUs):
9292
```bash
93-
export FLASHINFER_CUDA_ARCH_LIST="7.5 8.0 8.9 10.0a 10.3a 12.0a"
93+
export FLASHINFER_CUDA_ARCH_LIST="7.5 8.0 8.9 10.0a 10.3a 11.0a 12.0a 12.1a"
9494
cd flashinfer-jit-cache
9595
python -m build --no-isolation --wheel
9696
python -m pip install dist/*.whl

docs/installation.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ You can follow the steps below to install FlashInfer from source code:
9292

9393
.. code-block:: bash
9494
95-
export FLASHINFER_CUDA_ARCH_LIST="7.5 8.0 8.9 10.0a 10.3a 12.0a"
95+
export FLASHINFER_CUDA_ARCH_LIST="7.5 8.0 8.9 10.0a 10.3a 11.0a 12.0a 12.1a"
9696
cd flashinfer-jit-cache
9797
python -m build --no-isolation --wheel
9898
python -m pip install dist/*.whl

scripts/task_test_jit_cache_package_build_import.sh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ arches = ["7.5", "8.0", "8.9", "9.0a"]
3737
if cuda_ver is not None:
3838
try:
3939
major, minor = map(int, cuda_ver.split(".")[:2])
40+
if (major, minor) >= (13, 0):
41+
arches.append("11.0a")
42+
arches.append("12.1a")
4043
if (major, minor) >= (12, 8):
4144
arches.append("10.0a")
4245
arches.append("12.0a")

0 commit comments

Comments
 (0)