Skip to content

Commit dbdc726

Browse files
authored
fix GET_THREADS() for ROCm (#2997)
1 parent 8e878f0 commit dbdc726

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

torchvision/csrc/cuda/DeformConv_cuda.cu

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,9 @@
8181
const int kMaxParallelImgs = 32;
8282

8383
inline unsigned int GET_THREADS() {
84+
#ifdef __HIP_PLATFORM_HCC__
85+
return 256;
86+
#endif
8487
if (at::cuda::getCurrentDeviceProperties()->major >= 6) {
8588
return 1024;
8689
}

0 commit comments

Comments
 (0)