Skip to content

Commit 25120fb

Browse files
parmeetfmassa
authored andcommitted
[fbsync] Use TORCH_CUDA_ARCH_LIST to specify which CUDA architectures to build for (#3399)
Summary: * trying stuff * Put flags back? * remove second one just to see * It worked but it's because it's not built by CI. Trying another one * Try using TORCH_CUDA_ARCH_LIST instead * oops * Add new env variable to build/script_env * set TORCH_CUDA_ARCH_LIST for the rest of the CUDA versions * don't pass NVCC_FLAGS venv to conda, let's see if it works Reviewed By: fmassa Differential Revision: D27433927 fbshipit-source-id: e207f7fe6a1bab322e41d6dcabe7eb2b8e70b246 Co-authored-by: Francisco Massa <[email protected]>
1 parent d243606 commit 25120fb

File tree

2 files changed

+8
-20
lines changed

2 files changed

+8
-20
lines changed

packaging/pkg_helpers.bash

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,7 @@ setup_cuda() {
5656
export CUDA_HOME=/usr/local/cuda-11.2/
5757
fi
5858
export FORCE_CUDA=1
59-
# Hard-coding gencode flags is temporary situation until
60-
# https://github.com/pytorch/pytorch/pull/23408 lands
61-
export NVCC_FLAGS="-gencode=arch=compute_35,code=sm_35 -gencode=arch=compute_50,code=sm_50 -gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_86,code=sm_86 -gencode=arch=compute_50,code=compute_50"
59+
export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6"
6260
;;
6361
cu111)
6462
if [[ "$OSTYPE" == "msys" ]]; then
@@ -67,9 +65,7 @@ setup_cuda() {
6765
export CUDA_HOME=/usr/local/cuda-11.1/
6866
fi
6967
export FORCE_CUDA=1
70-
# Hard-coding gencode flags is temporary situation until
71-
# https://github.com/pytorch/pytorch/pull/23408 lands
72-
export NVCC_FLAGS="-gencode=arch=compute_35,code=sm_35 -gencode=arch=compute_50,code=sm_50 -gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_86,code=sm_86 -gencode=arch=compute_50,code=compute_50"
68+
export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6"
7369
;;
7470
cu110)
7571
if [[ "$OSTYPE" == "msys" ]]; then
@@ -78,9 +74,7 @@ setup_cuda() {
7874
export CUDA_HOME=/usr/local/cuda-11.0/
7975
fi
8076
export FORCE_CUDA=1
81-
# Hard-coding gencode flags is temporary situation until
82-
# https://github.com/pytorch/pytorch/pull/23408 lands
83-
export NVCC_FLAGS="-gencode=arch=compute_35,code=sm_35 -gencode=arch=compute_50,code=sm_50 -gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_50,code=compute_50"
77+
export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0"
8478
;;
8579
cu102)
8680
if [[ "$OSTYPE" == "msys" ]]; then
@@ -89,9 +83,7 @@ setup_cuda() {
8983
export CUDA_HOME=/usr/local/cuda-10.2/
9084
fi
9185
export FORCE_CUDA=1
92-
# Hard-coding gencode flags is temporary situation until
93-
# https://github.com/pytorch/pytorch/pull/23408 lands
94-
export NVCC_FLAGS="-gencode=arch=compute_35,code=sm_35 -gencode=arch=compute_50,code=sm_50 -gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_50,code=compute_50"
86+
export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5"
9587
;;
9688
cu101)
9789
if [[ "$OSTYPE" == "msys" ]]; then
@@ -100,9 +92,7 @@ setup_cuda() {
10092
export CUDA_HOME=/usr/local/cuda-10.1/
10193
fi
10294
export FORCE_CUDA=1
103-
# Hard-coding gencode flags is temporary situation until
104-
# https://github.com/pytorch/pytorch/pull/23408 lands
105-
export NVCC_FLAGS="-gencode=arch=compute_35,code=sm_35 -gencode=arch=compute_50,code=sm_50 -gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_50,code=compute_50"
95+
export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5"
10696
;;
10797
cu100)
10898
if [[ "$OSTYPE" == "msys" ]]; then
@@ -111,9 +101,7 @@ setup_cuda() {
111101
export CUDA_HOME=/usr/local/cuda-10.0/
112102
fi
113103
export FORCE_CUDA=1
114-
# Hard-coding gencode flags is temporary situation until
115-
# https://github.com/pytorch/pytorch/pull/23408 lands
116-
export NVCC_FLAGS="-gencode=arch=compute_35,code=sm_35 -gencode=arch=compute_50,code=sm_50 -gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_50,code=compute_50"
104+
export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5"
117105
;;
118106
cu92)
119107
if [[ "$OSTYPE" == "msys" ]]; then
@@ -122,7 +110,7 @@ setup_cuda() {
122110
export CUDA_HOME=/usr/local/cuda-9.2/
123111
fi
124112
export FORCE_CUDA=1
125-
export NVCC_FLAGS="-gencode=arch=compute_35,code=sm_35 -gencode=arch=compute_50,code=sm_50 -gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_50,code=compute_50"
113+
export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0"
126114
;;
127115
cpu)
128116
;;

packaging/torchvision/meta.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ build:
3737
script_env:
3838
- CUDA_HOME
3939
- FORCE_CUDA
40-
- NVCC_FLAGS
4140
- BUILD_VERSION
41+
- TORCH_CUDA_ARCH_LIST
4242
features:
4343
{{ environ.get('CONDA_CPUONLY_FEATURE') }}
4444

0 commit comments

Comments
 (0)