Skip to content

Commit 47f07ca

Browse files
author
Connor Baker
authored
Merge pull request #266081 from ConnorBaker/fix/torch-jetson
python3Packages.torch: patch cpp_extension.py for Jetson support
2 parents 417c205 + 2a42503 commit 47f07ca

1 file changed

Lines changed: 14 additions & 1 deletion

File tree

pkgs/development/python-modules/torch/default.nix

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,10 @@
4848

4949
let
5050
inherit (lib) attrsets lists strings trivial;
51-
inherit (cudaPackages) cudaFlags cudnn nccl;
51+
inherit (cudaPackages) cudaFlags cudnn;
52+
53+
# Some packages are not available on all platforms
54+
nccl = cudaPackages.nccl or null;
5255

5356
setBool = v: if v then "1" else "0";
5457

@@ -178,6 +181,13 @@ in buildPythonPackage rec {
178181
'message(FATAL_ERROR "Found NCCL header version and library version' \
179182
'message(WARNING "Found NCCL header version and library version'
180183
''
184+
# TODO(@connorbaker): Remove this patch after 2.1.0 lands.
185+
+ lib.optionalString cudaSupport ''
186+
substituteInPlace torch/utils/cpp_extension.py \
187+
--replace \
188+
"'8.6', '8.9'" \
189+
"'8.6', '8.7', '8.9'"
190+
''
181191
# error: no member named 'aligned_alloc' in the global namespace; did you mean simply 'aligned_alloc'
182192
# This lib overrided aligned_alloc hence the error message. Tltr: his function is linkable but not in header.
183193
+ lib.optionalString (stdenv.isDarwin && lib.versionOlder stdenv.targetPlatform.darwinSdkVersion "11.0") ''
@@ -253,6 +263,7 @@ in buildPythonPackage rec {
253263
PYTORCH_BUILD_VERSION = version;
254264
PYTORCH_BUILD_NUMBER = 0;
255265

266+
USE_NCCL = setBool (nccl != null);
256267
USE_SYSTEM_NCCL = setBool useSystemNccl; # don't build pytorch's third_party NCCL
257268
USE_STATIC_NCCL = setBool useSystemNccl;
258269

@@ -316,6 +327,8 @@ in buildPythonPackage rec {
316327
libcusolver.lib
317328
libcusparse.dev
318329
libcusparse.lib
330+
] ++ lists.optionals (nccl != null) [
331+
# Some platforms do not support NCCL (i.e., Jetson)
319332
nccl.dev # Provides nccl.h AND a static copy of NCCL!
320333
] ++ lists.optionals (strings.versionOlder cudaVersion "11.8") [
321334
cuda_nvprof.dev # <cuda_profiler_api.h>

0 commit comments

Comments
 (0)