|
48 | 48 |
|
49 | 49 | let |
50 | 50 | inherit (lib) attrsets lists strings trivial; |
51 | | - inherit (cudaPackages) cudaFlags cudnn nccl; |
| 51 | + inherit (cudaPackages) cudaFlags cudnn; |
| 52 | + |
| 53 | + # Some packages are not available on all platforms |
| 54 | + nccl = cudaPackages.nccl or null; |
52 | 55 |
|
53 | 56 | setBool = v: if v then "1" else "0"; |
54 | 57 |
|
@@ -178,6 +181,13 @@ in buildPythonPackage rec { |
178 | 181 | 'message(FATAL_ERROR "Found NCCL header version and library version' \ |
179 | 182 | 'message(WARNING "Found NCCL header version and library version' |
180 | 183 | '' |
| 184 | + # TODO(@connorbaker): Remove this patch after 2.1.0 lands. |
| 185 | + + lib.optionalString cudaSupport '' |
| 186 | + substituteInPlace torch/utils/cpp_extension.py \ |
| 187 | + --replace \ |
| 188 | + "'8.6', '8.9'" \ |
| 189 | + "'8.6', '8.7', '8.9'" |
| 190 | + '' |
181 | 191 | # error: no member named 'aligned_alloc' in the global namespace; did you mean simply 'aligned_alloc' |
182 | 192 | # This lib overrides aligned_alloc, hence the error message. TL;DR: this function is linkable but not declared in the header. |
183 | 193 | + lib.optionalString (stdenv.isDarwin && lib.versionOlder stdenv.targetPlatform.darwinSdkVersion "11.0") '' |
@@ -253,6 +263,7 @@ in buildPythonPackage rec { |
253 | 263 | PYTORCH_BUILD_VERSION = version; |
254 | 264 | PYTORCH_BUILD_NUMBER = 0; |
255 | 265 |
|
| 266 | + USE_NCCL = setBool (nccl != null); |
256 | 267 | USE_SYSTEM_NCCL = setBool useSystemNccl; # don't build pytorch's third_party NCCL |
257 | 268 | USE_STATIC_NCCL = setBool useSystemNccl; |
258 | 269 |
|
@@ -316,6 +327,8 @@ in buildPythonPackage rec { |
316 | 327 | libcusolver.lib |
317 | 328 | libcusparse.dev |
318 | 329 | libcusparse.lib |
| 330 | + ] ++ lists.optionals (nccl != null) [ |
| 331 | + # Some platforms do not support NCCL (e.g., Jetson) |
319 | 332 | nccl.dev # Provides nccl.h AND a static copy of NCCL! |
320 | 333 | ] ++ lists.optionals (strings.versionOlder cudaVersion "11.8") [ |
321 | 334 | cuda_nvprof.dev # <cuda_profiler_api.h> |
|
0 commit comments