We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent fb115c2 commit a3ba01aCopy full SHA for a3ba01a
.github/scripts/setup-env.sh
@@ -73,11 +73,21 @@ else
73
CHANNEL=nightly
74
fi
75
76
-pip install --progress-bar=off light-the-torch
77
-ltt install --progress-bar=off \
78
- --pytorch-computation-backend="${GPU_ARCH_TYPE}${GPU_ARCH_VERSION}" \
79
- --pytorch-channel="${CHANNEL}" \
80
- torch
+case $GPU_ARCH_TYPE in
+ cpu)
+ GPU_ARCH_ID="cpu"
+ ;;
+ cuda)
81
+ VERSION_WITHOUT_DOT=$(echo "${GPU_ARCH_VERSION}" | sed 's/\.//')
82
+ GPU_ARCH_ID="cu${VERSION_WITHOUT_DOT}"
83
84
+ *)
85
+ echo "Unknown GPU_ARCH_TYPE=${GPU_ARCH_TYPE}"
86
+ exit 1
87
88
+esac
89
+PYTORCH_WHEEL_INDEX="https://download.pytorch.org/whl/${CHANNEL}/${GPU_ARCH_ID}"
90
+pip install --progress-bar=off --pre torch --index-url="${PYTORCH_WHEEL_INDEX}"
91
92
if [[ $GPU_ARCH_TYPE == 'cuda' ]]; then
93
python -c "import torch; exit(not torch.cuda.is_available())"
0 commit comments