Skip to content

Commit f96c42f

Browse files
huydhnNicolasHug
andauthored
Re-enable vision MPS builds (#8485)
Co-authored-by: Nicolas Hug <[email protected]>
1 parent f1bcbd3 commit f96c42f

File tree

4 files changed

+6
-14
lines changed

4 files changed

+6
-14
lines changed

.github/workflows/build-cmake.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ jobs:
5454
export GPU_ARCH_TYPE=cpu
5555
export GPU_ARCH_VERSION=''
5656
57-
./.github/scripts/cmake.sh
57+
${CONDA_RUN} ./.github/scripts/cmake.sh
5858
5959
windows:
6060
strategy:

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ jobs:
6868
export GPU_ARCH_TYPE=cpu
6969
export GPU_ARCH_VERSION=''
7070
71-
./.github/scripts/unittest.sh
71+
${CONDA_RUN} ./.github/scripts/unittest.sh
7272
7373
unittests-windows:
7474
strategy:

setup.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import shutil
66
import subprocess
77
import sys
8-
import warnings
98

109
import torch
1110
from pkg_resources import DistributionNotFound, get_distribution, parse_version
@@ -139,6 +138,7 @@ def get_extensions():
139138
+ glob.glob(os.path.join(extensions_dir, "ops", "cpu", "*.cpp"))
140139
+ glob.glob(os.path.join(extensions_dir, "ops", "quantized", "cpu", "*.cpp"))
141140
)
141+
source_mps = glob.glob(os.path.join(extensions_dir, "ops", "mps", "*.mm"))
142142

143143
print("Compiling extensions with following flags:")
144144
force_cuda = os.getenv("FORCE_CUDA", "0") == "1"
@@ -204,15 +204,8 @@ def get_extensions():
204204
define_macros += [("WITH_HIP", None)]
205205
nvcc_flags = []
206206
extra_compile_args["nvcc"] = nvcc_flags
207-
208-
# FIXME: MPS build breaks custom ops registration, so it was disabled.
209-
# See https://github.com/pytorch/vision/issues/8456.
210-
# TODO: Fix MPS build, remove warning below, and put back commented-out elif block.V
211-
if force_mps:
212-
warnings.warn("MPS build is temporarily disabled!!!!")
213-
# elif torch.backends.mps.is_available() or force_mps:
214-
# source_mps = glob.glob(os.path.join(extensions_dir, "ops", "mps", "*.mm"))
215-
# sources += source_mps
207+
elif torch.backends.mps.is_available() or force_mps:
208+
sources += source_mps
216209

217210
if sys.platform == "win32":
218211
define_macros += [("torchvision_EXPORTS", None)]

test/conftest.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,7 @@ def pytest_collection_modifyitems(items):
4949
# There are special cases though, see below
5050
item.add_marker(pytest.mark.skip(reason=CUDA_NOT_AVAILABLE_MSG))
5151

52-
# TODO: uncoment when MPS works again - see FIXME in setup.py
53-
if needs_mps: # and not torch.backends.mps.is_available():
52+
if needs_mps and not torch.backends.mps.is_available():
5453
item.add_marker(pytest.mark.skip(reason=MPS_NOT_AVAILABLE_MSG))
5554

5655
if IN_FBCODE:

0 commit comments

Comments
 (0)