Skip to content

Commit 05cbea1

Browse files
jwfromm and facebook-github-bot
authored and committed
Hipify Pytorch3D (#1851)
Summary: X-link: pytorch/pytorch#133343 X-link: fairinternal/pytorch3d#45 Pull Request resolved: #1851 Enables pytorch3d to build on AMD. An important part of enabling this was not compiling the Pulsar backend when the target is AMD. There are simply too many kernel incompatibilities to make it work (I tried haha). Fortunately, it doesn't seem like most modern applications of pytorch3d rely on Pulsar. We should be able to unlock most of pytorch3d's goodness on AMD without it. Reviewed By: bottler, houseroad Differential Revision: D61171993 fbshipit-source-id: fd4aee378a3568b22676c5bf2b727c135ff710af
1 parent 38afdcf commit 05cbea1

File tree

6 files changed

+50
-3
lines changed

6 files changed

+50
-3
lines changed

pytorch3d/csrc/ext.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,15 @@
77
*/
88

99
// clang-format off
10+
#if !defined(USE_ROCM)
1011
#include "./pulsar/global.h" // Include before <torch/extension.h>.
12+
#endif
1113
#include <torch/extension.h>
1214
// clang-format on
15+
#if !defined(USE_ROCM)
1316
#include "./pulsar/pytorch/renderer.h"
1417
#include "./pulsar/pytorch/tensor_util.h"
18+
#endif
1519
#include "ball_query/ball_query.h"
1620
#include "blending/sigmoid_alpha_blend.h"
1721
#include "compositing/alpha_composite.h"
@@ -99,6 +103,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
99103
m.def("marching_cubes", &MarchingCubes);
100104

101105
// Pulsar.
106+
// Pulsar not enabled on AMD.
107+
#if !defined(USE_ROCM)
102108
#ifdef PULSAR_LOGGING_ENABLED
103109
c10::ShowLogInfoToStderr();
104110
#endif
@@ -183,4 +189,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
183189
m.attr("MAX_UINT") = py::int_(MAX_UINT);
184190
m.attr("MAX_USHORT") = py::int_(MAX_USHORT);
185191
m.attr("PULSAR_MAX_GRAD_SPHERES") = py::int_(MAX_GRAD_SPHERES);
192+
#endif
186193
}

pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ __device__ void CheckPixelInsideFace(
144144
const bool zero_face_area =
145145
(face_area <= kEpsilon && face_area >= -1.0f * kEpsilon);
146146

147-
if (zmax < 0 || cull_backfaces && back_face || outside_bbox ||
147+
if (zmax < 0 || (cull_backfaces && back_face) || outside_bbox ||
148148
zero_face_area) {
149149
return;
150150
}

pytorch3d/csrc/utils/float_math.cuh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ const auto vEpsilon = 1e-8;
1818

1919
// Common functions and operators for float2.
2020

21+
// Complex arithmetic is already defined for AMD.
22+
#if !defined(USE_ROCM)
2123
__device__ inline float2 operator-(const float2& a, const float2& b) {
2224
return make_float2(a.x - b.x, a.y - b.y);
2325
}
@@ -41,6 +43,7 @@ __device__ inline float2 operator*(const float2& a, const float2& b) {
4143
__device__ inline float2 operator*(const float a, const float2& b) {
4244
return make_float2(a * b.x, a * b.y);
4345
}
46+
#endif
4447

4548
__device__ inline float FloatMin3(const float a, const float b, const float c) {
4649
return fminf(a, fminf(b, c));

pytorch3d/csrc/utils/warp_reduce.cuh

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,37 +23,51 @@ WarpReduceMin(scalar_t* min_dists, int64_t* min_idxs, const size_t tid) {
2323
min_idxs[tid] = min_idxs[tid + 32];
2424
min_dists[tid] = min_dists[tid + 32];
2525
}
26+
// AMD does not use explicit syncwarp and instead automatically inserts memory
27+
// fences during compilation.
28+
#if !defined(USE_ROCM)
2629
__syncwarp();
30+
#endif
2731
// s = 16
2832
if (min_dists[tid] > min_dists[tid + 16]) {
2933
min_idxs[tid] = min_idxs[tid + 16];
3034
min_dists[tid] = min_dists[tid + 16];
3135
}
36+
#if !defined(USE_ROCM)
3237
__syncwarp();
38+
#endif
3339
// s = 8
3440
if (min_dists[tid] > min_dists[tid + 8]) {
3541
min_idxs[tid] = min_idxs[tid + 8];
3642
min_dists[tid] = min_dists[tid + 8];
3743
}
44+
#if !defined(USE_ROCM)
3845
__syncwarp();
46+
#endif
3947
// s = 4
4048
if (min_dists[tid] > min_dists[tid + 4]) {
4149
min_idxs[tid] = min_idxs[tid + 4];
4250
min_dists[tid] = min_dists[tid + 4];
4351
}
52+
#if !defined(USE_ROCM)
4453
__syncwarp();
54+
#endif
4555
// s = 2
4656
if (min_dists[tid] > min_dists[tid + 2]) {
4757
min_idxs[tid] = min_idxs[tid + 2];
4858
min_dists[tid] = min_dists[tid + 2];
4959
}
60+
#if !defined(USE_ROCM)
5061
__syncwarp();
62+
#endif
5163
// s = 1
5264
if (min_dists[tid] > min_dists[tid + 1]) {
5365
min_idxs[tid] = min_idxs[tid + 1];
5466
min_dists[tid] = min_dists[tid + 1];
5567
}
68+
#if !defined(USE_ROCM)
5669
__syncwarp();
70+
#endif
5771
}
5872

5973
template <typename scalar_t>
@@ -65,30 +79,42 @@ __device__ void WarpReduceMax(
6579
dists[tid] = dists[tid + 32];
6680
dists_idx[tid] = dists_idx[tid + 32];
6781
}
82+
#if !defined(USE_ROCM)
6883
__syncwarp();
84+
#endif
6985
if (dists[tid] < dists[tid + 16]) {
7086
dists[tid] = dists[tid + 16];
7187
dists_idx[tid] = dists_idx[tid + 16];
7288
}
89+
#if !defined(USE_ROCM)
7390
__syncwarp();
91+
#endif
7492
if (dists[tid] < dists[tid + 8]) {
7593
dists[tid] = dists[tid + 8];
7694
dists_idx[tid] = dists_idx[tid + 8];
7795
}
96+
#if !defined(USE_ROCM)
7897
__syncwarp();
98+
#endif
7999
if (dists[tid] < dists[tid + 4]) {
80100
dists[tid] = dists[tid + 4];
81101
dists_idx[tid] = dists_idx[tid + 4];
82102
}
103+
#if !defined(USE_ROCM)
83104
__syncwarp();
105+
#endif
84106
if (dists[tid] < dists[tid + 2]) {
85107
dists[tid] = dists[tid + 2];
86108
dists_idx[tid] = dists_idx[tid + 2];
87109
}
110+
#if !defined(USE_ROCM)
88111
__syncwarp();
112+
#endif
89113
if (dists[tid] < dists[tid + 1]) {
90114
dists[tid] = dists[tid + 1];
91115
dists_idx[tid] = dists_idx[tid + 1];
92116
}
117+
#if !defined(USE_ROCM)
93118
__syncwarp();
119+
#endif
94120
}

pytorch3d/renderer/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
# pyre-unsafe
88

9+
import torch
10+
911
from .blending import (
1012
BlendParams,
1113
hard_rgb_blend,
@@ -74,9 +76,13 @@
7476
PointsRasterizationSettings,
7577
PointsRasterizer,
7678
PointsRenderer,
77-
PulsarPointsRenderer,
7879
rasterize_points,
7980
)
81+
82+
# Pulsar is not enabled on amd.
83+
if not torch.version.hip:
84+
from .points import PulsarPointsRenderer
85+
8086
from .splatter_blend import SplatterBlender
8187
from .utils import (
8288
convert_to_tensors_and_broadcast,

pytorch3d/renderer/points/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,13 @@
66

77
# pyre-unsafe
88

9+
import torch
10+
911
from .compositor import AlphaCompositor, NormWeightedCompositor
10-
from .pulsar.unified import PulsarPointsRenderer
12+
13+
# Pulsar not enabled on amd.
14+
if not torch.version.hip:
15+
from .pulsar.unified import PulsarPointsRenderer
1116

1217
from .rasterize_points import rasterize_points
1318
from .rasterizer import PointsRasterizationSettings, PointsRasterizer

0 commit comments

Comments
 (0)