|
5 | 5 | import shutil
|
6 | 6 | import subprocess
|
7 | 7 | import sys
|
| 8 | +import warnings |
8 | 9 |
|
9 | 10 | import torch
|
10 | 11 | from pkg_resources import DistributionNotFound, get_distribution, parse_version
|
@@ -138,7 +139,6 @@ def get_extensions():
|
138 | 139 | + glob.glob(os.path.join(extensions_dir, "ops", "cpu", "*.cpp"))
|
139 | 140 | + glob.glob(os.path.join(extensions_dir, "ops", "quantized", "cpu", "*.cpp"))
|
140 | 141 | )
|
141 |
| - source_mps = glob.glob(os.path.join(extensions_dir, "ops", "mps", "*.mm")) |
142 | 142 |
|
143 | 143 | print("Compiling extensions with following flags:")
|
144 | 144 | force_cuda = os.getenv("FORCE_CUDA", "0") == "1"
|
@@ -204,8 +204,15 @@ def get_extensions():
|
204 | 204 | define_macros += [("WITH_HIP", None)]
|
205 | 205 | nvcc_flags = []
|
206 | 206 | extra_compile_args["nvcc"] = nvcc_flags
|
207 |
| - elif torch.backends.mps.is_available() or force_mps: |
208 |
| - sources += source_mps |
| 207 | + |
| 208 | + # FIXME: MPS build breaks custom ops registration, so it was disabled. |
| 209 | + # See https://github.com/pytorch/vision/issues/8456. |
| 210 | + # TODO: Fix MPS build, remove warning below, and put back commented-out elif block.V |
| 211 | + if force_mps: |
| 212 | + warnings.warn("MPS build is temporarily disabled!!!!") |
| 213 | + # elif torch.backends.mps.is_available() or force_mps: |
| 214 | + # source_mps = glob.glob(os.path.join(extensions_dir, "ops", "mps", "*.mm")) |
| 215 | + # sources += source_mps |
209 | 216 |
|
210 | 217 | if sys.platform == "win32":
|
211 | 218 | define_macros += [("torchvision_EXPORTS", None)]
|
|
0 commit comments