 #include <tuple>
 
 template <typename scalar_t>
-__global__ void FaceAreasNormalsKernel(
+__global__ void FaceAreasNormalsForwardKernel(
     const scalar_t* __restrict__ verts,
-    const long* __restrict__ faces,
+    const int64_t* __restrict__ faces,
     scalar_t* __restrict__ face_areas,
     scalar_t* __restrict__ face_normals,
     const size_t V,
@@ -18,9 +18,9 @@ __global__ void FaceAreasNormalsKernel(
   // Each thread computes the area & normal of its respective faces and adds it
   // to the global face_areas tensor.
   for (size_t f = tid; f < F; f += stride) {
-    const long i0 = faces[3 * f + 0];
-    const long i1 = faces[3 * f + 1];
-    const long i2 = faces[3 * f + 2];
+    const int64_t i0 = faces[3 * f + 0];
+    const int64_t i1 = faces[3 * f + 1];
+    const int64_t i2 = faces[3 * f + 2];
 
     const scalar_t v0_x = verts[3 * i0 + 0];
     const scalar_t v0_y = verts[3 * i0 + 1];
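Note (forward math, inferred from the identical arithmetic repeated in the backward kernel added below; the middle of the forward kernel body is elided by this diff view): for each face $f$ with vertices $v_0, v_1, v_2$, the kernel forms

$$\mathbf{a} = v_1 - v_0,\quad \mathbf{b} = v_2 - v_0,\quad \mathbf{c} = \mathbf{a}\times\mathbf{b}, \qquad \texttt{face\_areas}[f] = \tfrac{1}{2}\lVert\mathbf{c}\rVert,\quad \texttt{face\_normals}[f] = \mathbf{c}/\lVert\mathbf{c}\rVert,$$

with $\lVert\mathbf{c}\rVert$ presumably clamped at $10^{-6}$, as in the backward kernel, to guard degenerate faces.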
@@ -55,9 +55,161 @@ __global__ void FaceAreasNormalsKernel(
   }
 }
 
-std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsCuda(
-    at::Tensor verts,
-    at::Tensor faces) {
+// TODO(gkioxari) support all data types once AtomicAdd supports doubles.
+// Currently, support is for floats only.
+__global__ void FaceAreasNormalsBackwardKernel(
+    const float* __restrict__ grad_areas,
+    const float* __restrict__ grad_normals,
+    const float* __restrict__ verts,
+    const int64_t* __restrict__ faces,
+    float* __restrict__ grad_verts,
+    const size_t V,
+    const size_t F) {
+  const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
+  const size_t stride = gridDim.x * blockDim.x;
+
+  // Faces split evenly over the number of threads in the grid.
+  // Each thread computes the gradient contributions of its respective faces
+  // and adds them to the global grad_verts tensor.
+  for (size_t f = tid; f < F; f += stride) {
+    const int64_t i0 = faces[3 * f + 0];
+    const int64_t i1 = faces[3 * f + 1];
+    const int64_t i2 = faces[3 * f + 2];
+
+    const float v0_x = verts[3 * i0 + 0];
+    const float v0_y = verts[3 * i0 + 1];
+    const float v0_z = verts[3 * i0 + 2];
+
+    const float v1_x = verts[3 * i1 + 0];
+    const float v1_y = verts[3 * i1 + 1];
+    const float v1_z = verts[3 * i1 + 2];
+
+    const float v2_x = verts[3 * i2 + 0];
+    const float v2_y = verts[3 * i2 + 1];
+    const float v2_z = verts[3 * i2 + 2];
+
+    const float ax = v1_x - v0_x;
+    const float ay = v1_y - v0_y;
+    const float az = v1_z - v0_z;
+
+    const float bx = v2_x - v0_x;
+    const float by = v2_y - v0_y;
+    const float bz = v2_z - v0_z;
+
+    const float cx = ay * bz - az * by;
+    const float cy = az * bx - ax * bz;
+    const float cz = ax * by - ay * bx;
+
+    float norm = sqrt(cx * cx + cy * cy + cz * cz);
+    norm = (norm < 1e-6) ? 1e-6 : norm; // max(norm, 1e-6)
+    float inv_norm = 1. / norm;
+    float inv_norm_2 = pow(inv_norm, 2.0f);
+    float inv_norm_3 = pow(inv_norm, 3.0f);
+
+    // We compute gradients with respect to the input vertices.
+    // For each vertex, gradients come from grad_areas and grad_normals.
+    // e.g. grad_v0_x = (d / d v0_x)
+    //    = \sum_f (d / d areas[f]) * (d areas[f] / d v0_x)
+    //    + (d / d normals[f, 0]) * (d normals[f, 0] / d v0_x)
+    //    + (d / d normals[f, 1]) * (d normals[f, 1] / d v0_x)
+    //    + (d / d normals[f, 2]) * (d normals[f, 2] / d v0_x)
+    // with (d / d areas[f]) = grad_areas[f] and
+    //      (d / d normals[f, j]) = grad_normals[f][j].
+    // The equations below are derived by taking
+    // derivatives with respect to the vertices (fun times!).
+
+    // grad v0 coming from grad areas and grad normals
+    const float grad_v0_x =
+        ((-az + bz) * cy + (-by + ay) * cz) / 2.0 * inv_norm * grad_areas[f] +
+        -cx * ((-az + bz) * cy + (-by + ay) * cz) * inv_norm_3 *
+            grad_normals[3 * f + 0] +
+        ((-az + bz) - cy * ((-az + bz) * cy + (-by + ay) * cz) * inv_norm_2) *
+            inv_norm * grad_normals[3 * f + 1] +
+        ((-by + ay) - cz * ((-az + bz) * cy + (-by + ay) * cz) * inv_norm_2) *
+            inv_norm * grad_normals[3 * f + 2];
+    atomicAdd(grad_verts + 3 * i0 + 0, grad_v0_x);
+
+    const float grad_v0_y =
+        ((-bz + az) * cx + (-ax + bx) * cz) / 2.0 * inv_norm * grad_areas[f] +
+        ((-bz + az) - cx * ((-bz + az) * cx + (-ax + bx) * cz) * inv_norm_2) *
+            inv_norm * grad_normals[3 * f + 0] +
+        -cy * ((-bz + az) * cx + (-ax + bx) * cz) * inv_norm_3 *
+            grad_normals[3 * f + 1] +
+        ((-ax + bx) - cz * ((-bz + az) * cx + (-ax + bx) * cz) * inv_norm_2) *
+            inv_norm * grad_normals[3 * f + 2];
+    atomicAdd(grad_verts + 3 * i0 + 1, grad_v0_y);
+
+    const float grad_v0_z =
+        ((-ay + by) * cx + (-bx + ax) * cy) / 2.0 * inv_norm * grad_areas[f] +
+        ((-ay + by) - cx * ((-ay + by) * cx + (-bx + ax) * cy) * inv_norm_2) *
+            inv_norm * grad_normals[3 * f + 0] +
+        ((-bx + ax) - cy * ((-ay + by) * cx + (-bx + ax) * cy) * inv_norm_2) *
+            inv_norm * grad_normals[3 * f + 1] +
+        -cz * ((-ay + by) * cx + (-bx + ax) * cy) * inv_norm_3 *
+            grad_normals[3 * f + 2];
+    atomicAdd(grad_verts + 3 * i0 + 2, grad_v0_z);
+
+    // grad v1 coming from grad areas and grad normals
+    const float grad_v1_x =
+        (by * cz - bz * cy) / 2.0 * inv_norm * grad_areas[f] +
+        -cx * (by * cz - bz * cy) * inv_norm_3 * grad_normals[3 * f + 0] +
+        (-bz - cy * (by * cz - bz * cy) * inv_norm_2) * inv_norm *
+            grad_normals[3 * f + 1] +
+        (by - cz * (by * cz - bz * cy) * inv_norm_2) * inv_norm *
+            grad_normals[3 * f + 2];
+    atomicAdd(grad_verts + 3 * i1 + 0, grad_v1_x);
+
+    const float grad_v1_y =
+        (bz * cx - bx * cz) / 2.0 * inv_norm * grad_areas[f] +
+        (bz - cx * (bz * cx - bx * cz) * inv_norm_2) * inv_norm *
+            grad_normals[3 * f + 0] +
+        -cy * (bz * cx - bx * cz) * inv_norm_3 * grad_normals[3 * f + 1] +
+        (-bx - cz * (bz * cx - bx * cz) * inv_norm_2) * inv_norm *
+            grad_normals[3 * f + 2];
+    atomicAdd(grad_verts + 3 * i1 + 1, grad_v1_y);
+
+    const float grad_v1_z =
+        (bx * cy - by * cx) / 2.0 * inv_norm * grad_areas[f] +
+        (-by - cx * (bx * cy - by * cx) * inv_norm_2) * inv_norm *
+            grad_normals[3 * f + 0] +
+        (bx - cy * (bx * cy - by * cx) * inv_norm_2) * inv_norm *
+            grad_normals[3 * f + 1] +
+        -cz * (bx * cy - by * cx) * inv_norm_3 * grad_normals[3 * f + 2];
+    atomicAdd(grad_verts + 3 * i1 + 2, grad_v1_z);
+
+    // grad v2 coming from grad areas and grad normals
+    const float grad_v2_x =
+        (az * cy - ay * cz) / 2.0 * inv_norm * grad_areas[f] +
+        -cx * (az * cy - ay * cz) * inv_norm_3 * grad_normals[3 * f + 0] +
+        (az - cy * (az * cy - ay * cz) * inv_norm_2) * inv_norm *
+            grad_normals[3 * f + 1] +
+        (-ay - cz * (az * cy - ay * cz) * inv_norm_2) * inv_norm *
+            grad_normals[3 * f + 2];
+    atomicAdd(grad_verts + 3 * i2 + 0, grad_v2_x);
+
+    const float grad_v2_y =
+        (ax * cz - az * cx) / 2.0 * inv_norm * grad_areas[f] +
+        (-az - cx * (ax * cz - az * cx) * inv_norm_2) * inv_norm *
+            grad_normals[3 * f + 0] +
+        -cy * (ax * cz - az * cx) * inv_norm_3 * grad_normals[3 * f + 1] +
+        (ax - cz * (ax * cz - az * cx) * inv_norm_2) * inv_norm *
+            grad_normals[3 * f + 2];
+    atomicAdd(grad_verts + 3 * i2 + 1, grad_v2_y);
+
+    const float grad_v2_z =
+        (ay * cx - ax * cy) / 2.0 * inv_norm * grad_areas[f] +
+        (ay - cx * (ay * cx - ax * cy) * inv_norm_2) * inv_norm *
+            grad_normals[3 * f + 0] +
+        (-ax - cy * (ay * cx - ax * cy) * inv_norm_2) * inv_norm *
+            grad_normals[3 * f + 1] +
+        -cz * (ay * cx - ax * cy) * inv_norm_3 * grad_normals[3 * f + 2];
+    atomicAdd(grad_verts + 3 * i2 + 2, grad_v2_z);
+  }
+}
+
+std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsForwardCuda(
+    const at::Tensor verts,
+    const at::Tensor faces) {
   const auto V = verts.size(0);
   const auto F = faces.size(0);
 
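Note on where the closed forms in the backward kernel come from (a derivation sketch in the kernel's own variable names): with $\mathbf{a} = v_1 - v_0$, $\mathbf{b} = v_2 - v_0$ and $\mathbf{c} = \mathbf{a} \times \mathbf{b}$, the area is $A = \tfrac{1}{2}\lVert\mathbf{c}\rVert$ and the normal components are $n_j = c_j / \lVert\mathbf{c}\rVert$. For any vertex coordinate $v$, the chain rule gives

$$\frac{\partial A}{\partial v} = \frac{1}{2\lVert\mathbf{c}\rVert}\,\mathbf{c}^{\top}\frac{\partial \mathbf{c}}{\partial v}, \qquad \frac{\partial n_j}{\partial v} = \frac{1}{\lVert\mathbf{c}\rVert}\frac{\partial c_j}{\partial v} - \frac{c_j}{\lVert\mathbf{c}\rVert^{3}}\,\mathbf{c}^{\top}\frac{\partial \mathbf{c}}{\partial v}.$$

The factors inv_norm, inv_norm_2 and inv_norm_3 supply the powers of $1/\lVert\mathbf{c}\rVert$, and each repeated parenthesized expression is $\mathbf{c}^{\top}\,\partial\mathbf{c}/\partial v$ expanded by hand for one coordinate; e.g. for $v_{0,x}$: $\partial c_x/\partial v_{0,x} = 0$, $\partial c_y/\partial v_{0,x} = b_z - a_z$ and $\partial c_z/\partial v_{0,x} = a_y - b_y$, which yields the $((-a_z + b_z)\,c_y + (-b_y + a_y)\,c_z)$ factor in grad_v0_x.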
@@ -66,16 +218,42 @@ std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsCuda(
 
   const int blocks = 64;
   const int threads = 512;
-  AT_DISPATCH_FLOATING_TYPES(verts.type(), "face_areas_normals_cuda", ([&] {
-    FaceAreasNormalsKernel<scalar_t>
-        <<<blocks, threads>>>(
-            verts.data_ptr<scalar_t>(),
-            faces.data_ptr<long>(),
-            areas.data_ptr<scalar_t>(),
-            normals.data_ptr<scalar_t>(),
-            V,
-            F);
-  }));
+  AT_DISPATCH_FLOATING_TYPES(
+      verts.type(), "face_areas_normals_forward_cuda", ([&] {
+        FaceAreasNormalsForwardKernel<scalar_t><<<blocks, threads>>>(
+            verts.data_ptr<scalar_t>(),
+            faces.data_ptr<int64_t>(),
+            areas.data_ptr<scalar_t>(),
+            normals.data_ptr<scalar_t>(),
+            V,
+            F);
+      }));
 
   return std::make_tuple(areas, normals);
 }
+
+at::Tensor FaceAreasNormalsBackwardCuda(
+    const at::Tensor grad_areas,
+    const at::Tensor grad_normals,
+    const at::Tensor verts,
+    const at::Tensor faces) {
+  const auto V = verts.size(0);
+  const auto F = faces.size(0);
+
+  at::Tensor grad_verts = at::zeros({V, 3}, grad_areas.options());
+
+  const int blocks = 64;
+  const int threads = 512;
+  // TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports
+  // doubles. Currently, support is for floats only.
+  FaceAreasNormalsBackwardKernel<<<blocks, threads>>>(
+      grad_areas.data_ptr<float>(),
+      grad_normals.data_ptr<float>(),
+      verts.data_ptr<float>(),
+      faces.data_ptr<int64_t>(),
+      grad_verts.data_ptr<float>(),
+      V,
+      F);
+
+  return grad_verts;
+}
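Note on the two TODOs about double support: native atomicAdd(double*, double) is available on GPUs of compute capability 6.0 and newer, and the CUDA C Programming Guide gives an atomicCAS-based emulation for older architectures. A minimal sketch of how the backward kernel could become dtype-generic, assuming a hypothetical AtomicAdd wrapper that is not part of this commit:

    // Sketch only: dtype-generic atomic add, so the backward kernel could be
    // templated on scalar_t and launched under AT_DISPATCH_FLOATING_TYPES.
    template <typename scalar_t>
    __device__ __forceinline__ void AtomicAdd(scalar_t* address, scalar_t val) {
      atomicAdd(address, val); // float everywhere; double natively on sm_60+
    }

    #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600
    // Pre-Pascal fallback: emulate double atomicAdd with a compare-and-swap
    // loop over the 64-bit integer view of the value, following the CUDA C
    // Programming Guide.
    template <>
    __device__ __forceinline__ void AtomicAdd<double>(double* address, double val) {
      unsigned long long int* address_as_ull =
          reinterpret_cast<unsigned long long int*>(address);
      unsigned long long int old = *address_as_ull;
      unsigned long long int assumed;
      do {
        assumed = old;
        old = atomicCAS(
            address_as_ull,
            assumed,
            __double_as_longlong(val + __longlong_as_double(assumed)));
        // Integer comparison avoids looping forever if the value is NaN.
      } while (assumed != old);
    }
    #endif

With such a wrapper, each atomicAdd(grad_verts + ..., grad_v...) call in the backward kernel would become AtomicAdd(grad_verts + ..., grad_v...), and the launch could sit inside an AT_DISPATCH_FLOATING_TYPES block like the forward pass.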