Skip to content

Commit 4393f7d

Browse files
authored
Remove broken MPS build (pytorch#8472)
1 parent b6770a7 commit 4393f7d

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

setup.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import shutil
66
import subprocess
77
import sys
8+
import warnings
89

910
import torch
1011
from pkg_resources import DistributionNotFound, get_distribution, parse_version
@@ -138,7 +139,6 @@ def get_extensions():
138139
+ glob.glob(os.path.join(extensions_dir, "ops", "cpu", "*.cpp"))
139140
+ glob.glob(os.path.join(extensions_dir, "ops", "quantized", "cpu", "*.cpp"))
140141
)
141-
source_mps = glob.glob(os.path.join(extensions_dir, "ops", "mps", "*.mm"))
142142

143143
print("Compiling extensions with following flags:")
144144
force_cuda = os.getenv("FORCE_CUDA", "0") == "1"
@@ -204,8 +204,15 @@ def get_extensions():
204204
define_macros += [("WITH_HIP", None)]
205205
nvcc_flags = []
206206
extra_compile_args["nvcc"] = nvcc_flags
207-
elif torch.backends.mps.is_available() or force_mps:
208-
sources += source_mps
207+
208+
# FIXME: MPS build breaks custom ops registration, so it was disabled.
209+
# See https://github.com/pytorch/vision/issues/8456.
210+
# TODO: Fix MPS build, remove warning below, and put back commented-out elif block.V
211+
if force_mps:
212+
warnings.warn("MPS build is temporarily disabled!!!!")
213+
# elif torch.backends.mps.is_available() or force_mps:
214+
# source_mps = glob.glob(os.path.join(extensions_dir, "ops", "mps", "*.mm"))
215+
# sources += source_mps
209216

210217
if sys.platform == "win32":
211218
define_macros += [("torchvision_EXPORTS", None)]

0 commit comments

Comments
 (0)