
Commit 5359977

oawiles authored and facebook-github-bot committed
Accumulate points (#4)
Summary: Code for accumulating points in the z-buffer in three ways:
1. weighted sum
2. normalised weighted sum
3. alpha compositing

Pull Request resolved: fairinternal/pytorch3d#4
Reviewed By: nikhilaravi
Differential Revision: D20522422
Pulled By: gkioxari
fbshipit-source-id: 5023baa05f15e338f3821ef08f5552c2dcbfc06c
1 parent 5218f45 · commit 5359977

21 files changed: +2466 −4 lines
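The summary lists three ways of accumulating the K z-sorted points that cover a pixel. As a rough orientation before the diffs, the sketch below applies all three rules to a single pixel in plain C++; the feature and alpha values and the variable names are illustrative assumptions, not code from this commit.

// Minimal, hypothetical sketch of the three accumulation rules named in the
// commit summary, applied to the K z-sorted points covering one pixel.
#include <cstdio>
#include <vector>

int main() {
  // Features and alphas of the K nearest points for one pixel, front to back.
  std::vector<float> f = {0.8f, 0.5f, 0.2f};
  std::vector<float> a = {0.6f, 0.3f, 0.9f};

  float wsum = 0.f, alpha_total = 0.f;
  float composite = 0.f, cum_alpha = 1.f;
  for (size_t k = 0; k < f.size(); ++k) {
    // 1. weighted sum: each point contributes alpha_k * f_k.
    wsum += a[k] * f[k];
    alpha_total += a[k];
    // 3. alpha compositing: weight by alpha_k times the transmittance of the
    // points in front of it, prod_{l<k} (1 - alpha_l).
    composite += a[k] * cum_alpha * f[k];
    cum_alpha *= (1.f - a[k]);
  }
  // 2. normalised weighted sum: divide the weighted sum by the total weight.
  float wsum_norm = alpha_total > 0.f ? wsum / alpha_total : 0.f;

  std::printf("weighted sum          = %.4f\n", wsum);
  std::printf("normalised weighted   = %.4f\n", wsum_norm);
  std::printf("alpha composite       = %.4f\n", composite);
  return 0;
}

With these numbers the third point's composite weight (0.9 * 0.4 * 0.7) is strongly attenuated by the two points in front of it, which is exactly the per-pixel, per-channel behaviour the alpha-compositing kernels in this commit implement.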

docs/tutorials/deform_source_mesh_to_target_mesh.ipynb

Lines changed: 3 additions & 3 deletions
@@ -673,9 +673,9 @@
   "provenance": []
  },
  "kernelspec": {
-  "display_name": "pytorch3d (local)",
+  "display_name": "Python 3",
   "language": "python",
-  "name": "pytorch3d_local"
+  "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
@@ -687,7 +687,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-  "version": "3.7.5+"
+  "version": "3.7.6"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {

docs/tutorials/render_coloured_points.ipynb

Lines changed: 303 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 187 additions & 0 deletions
@@ -0,0 +1,187 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

#include <torch/extension.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <stdio.h>
#include <vector>

// TODO(gkioxari) support all data types once AtomicAdd supports doubles.
// Currently, support is for floats only.
__global__ void alphaCompositeCudaForwardKernel(
    // clang-format off
    torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> result,
    const torch::PackedTensorAccessor<float, 2, torch::RestrictPtrTraits, size_t> features,
    const torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> alphas,
    const torch::PackedTensorAccessor<int64_t, 4, torch::RestrictPtrTraits, size_t> points_idx) {
  // clang-format on
  const int64_t batch_size = result.size(0);
  const int64_t C = features.size(0);
  const int64_t H = points_idx.size(2);
  const int64_t W = points_idx.size(3);

  // Get the batch and index
  const int batch = blockIdx.x;

  const int num_pixels = C * W * H;
  const int num_threads = gridDim.y * blockDim.x;
  const int tid = blockIdx.y * blockDim.x + threadIdx.x;

  // Iterate over each feature in each pixel
  for (int pid = tid; pid < num_pixels; pid += num_threads) {
    int ch = pid / (W * H);
    int j = (pid % (W * H)) / H;
    int i = (pid % (W * H)) % H;

    // alphacomposite the different values
    float cum_alpha = 1.;
    // Iterate through the closest K points for this pixel
    for (int k = 0; k < points_idx.size(1); ++k) {
      int n_idx = points_idx[batch][k][j][i];

      // Sentinel value is -1 indicating no point overlaps the pixel
      if (n_idx < 0) {
        continue;
      }

      float alpha = alphas[batch][k][j][i];
      // TODO(gkioxari) It might be more efficient to have threads write in a
      // local variable, and move atomicAdd outside of the loop such that
      // atomicAdd is executed once per thread.
      atomicAdd(
          &result[batch][ch][j][i], features[ch][n_idx] * cum_alpha * alpha);
      cum_alpha = cum_alpha * (1 - alpha);
    }
  }
}

// TODO(gkioxari) support all data types once AtomicAdd supports doubles.
// Currently, support is for floats only.
__global__ void alphaCompositeCudaBackwardKernel(
    // clang-format off
    torch::PackedTensorAccessor<float, 2, torch::RestrictPtrTraits, size_t> grad_features,
    torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> grad_alphas,
    const torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> grad_outputs,
    const torch::PackedTensorAccessor<float, 2, torch::RestrictPtrTraits, size_t> features,
    const torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> alphas,
    const torch::PackedTensorAccessor<int64_t, 4, torch::RestrictPtrTraits, size_t> points_idx) {
  // clang-format on
  const int64_t batch_size = points_idx.size(0);
  const int64_t C = features.size(0);
  const int64_t H = points_idx.size(2);
  const int64_t W = points_idx.size(3);

  // Get the batch and index
  const int batch = blockIdx.x;

  const int num_pixels = C * W * H;
  const int num_threads = gridDim.y * blockDim.x;
  const int tid = blockIdx.y * blockDim.x + threadIdx.x;

  // Parallelize over each feature in each pixel in images of size H * W,
  // for each image in the batch of size batch_size
  for (int pid = tid; pid < num_pixels; pid += num_threads) {
    int ch = pid / (W * H);
    int j = (pid % (W * H)) / H;
    int i = (pid % (W * H)) % H;

    // alphacomposite the different values
    float cum_alpha = 1.;
    // Iterate through the closest K points for this pixel
    for (int k = 0; k < points_idx.size(1); ++k) {
      int n_idx = points_idx[batch][k][j][i];

      // Sentinel value is -1 indicating no point overlaps the pixel
      if (n_idx < 0) {
        continue;
      }
      float alpha = alphas[batch][k][j][i];

      // TODO(gkioxari) It might be more efficient to have threads write in a
      // local variable, and move atomicAdd outside of the loop such that
      // atomicAdd is executed once per thread.
      atomicAdd(
          &grad_alphas[batch][k][j][i],
          cum_alpha * features[ch][n_idx] * grad_outputs[batch][ch][j][i]);
      atomicAdd(
          &grad_features[ch][n_idx],
          cum_alpha * alpha * grad_outputs[batch][ch][j][i]);

      // Iterate over all (K-1) nearest points to update gradient
      for (int t = 0; t < k; ++t) {
        int t_idx = points_idx[batch][t][j][i];
        // Sentinel value is -1, indicating no point overlaps this pixel
        if (t_idx < 0) {
          continue;
        }
        float alpha_tvalue = alphas[batch][t][j][i];
        // TODO(gkioxari) It might be more efficient to have threads write in a
        // local variable, and move atomicAdd outside of the loop such that
        // atomicAdd is executed once per thread.
        atomicAdd(
            &grad_alphas[batch][t][j][i],
            -grad_outputs[batch][ch][j][i] * features[ch][n_idx] * cum_alpha *
                alpha / (1 - alpha_tvalue));
      }

      cum_alpha = cum_alpha * (1 - alphas[batch][k][j][i]);
    }
  }
}

torch::Tensor alphaCompositeCudaForward(
    const torch::Tensor& features,
    const torch::Tensor& alphas,
    const torch::Tensor& points_idx) {
  const int64_t batch_size = points_idx.size(0);
  const int64_t C = features.size(0);
  const int64_t H = points_idx.size(2);
  const int64_t W = points_idx.size(3);

  auto result = torch::zeros({batch_size, C, H, W}, features.options());

  const dim3 threadsPerBlock(64);
  const dim3 numBlocks(batch_size, 1024 / batch_size + 1);

  // TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports
  // doubles. Currently, support is for floats only.
  alphaCompositeCudaForwardKernel<<<numBlocks, threadsPerBlock>>>(
      // clang-format off
      result.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
      features.packed_accessor<float, 2, torch::RestrictPtrTraits, size_t>(),
      alphas.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
      points_idx.packed_accessor<int64_t, 4, torch::RestrictPtrTraits, size_t>());
  // clang-format on

  return result;
}

std::tuple<torch::Tensor, torch::Tensor> alphaCompositeCudaBackward(
    const torch::Tensor& grad_outputs,
    const torch::Tensor& features,
    const torch::Tensor& alphas,
    const torch::Tensor& points_idx) {
  auto grad_features = torch::zeros_like(features);
  auto grad_alphas = torch::zeros_like(alphas);

  const int64_t bs = alphas.size(0);

  const dim3 threadsPerBlock(64);
  const dim3 numBlocks(bs, 1024 / bs + 1);

  // TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports
  // doubles. Currently, support is for floats only.
  alphaCompositeCudaBackwardKernel<<<numBlocks, threadsPerBlock>>>(
      // clang-format off
      grad_features.packed_accessor<float, 2, torch::RestrictPtrTraits, size_t>(),
      grad_alphas.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
      grad_outputs.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
      features.packed_accessor<float, 2, torch::RestrictPtrTraits, size_t>(),
      alphas.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
      points_idx.packed_accessor<int64_t, 4, torch::RestrictPtrTraits, size_t>());
  // clang-format on

  return std::make_tuple(grad_features, grad_alphas);
}
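In the backward kernel above, the gradient with respect to an earlier alpha_t is recovered by dividing a later point's accumulated weight by (1 - alpha_t), because alpha_t enters the cumulative transmittance exactly once as a (1 - alpha_t) factor. The plain-C++ sketch below checks that expression against finite differences for one pixel; the values are made up and the helper is purely illustrative, not the CPU path added elsewhere in this commit.

// Finite-difference check of the per-pixel alpha-compositing gradient
// (illustrative sketch, not code from this commit).
#include <cmath>
#include <cstdio>
#include <vector>

// Forward composite for one pixel/channel:
// out = sum_k f_k * alpha_k * prod_{l<k} (1 - alpha_l).
float composite(const std::vector<float>& f, const std::vector<float>& a) {
  float out = 0.f, cum = 1.f;
  for (size_t k = 0; k < f.size(); ++k) {
    out += f[k] * a[k] * cum;
    cum *= (1.f - a[k]);
  }
  return out;
}

int main() {
  std::vector<float> f = {0.8f, 0.5f, 0.2f};  // features of the K points
  std::vector<float> a = {0.6f, 0.3f, 0.9f};  // their alphas, front to back

  // Analytic gradient d out / d alpha_t, mirroring the backward kernel:
  // the direct term f_t * cum_t, plus for every later point k > t the term
  // -f_k * alpha_k * cum_k / (1 - alpha_t).
  std::vector<float> grad(a.size(), 0.f);
  float cum = 1.f;
  for (size_t k = 0; k < a.size(); ++k) {
    grad[k] += cum * f[k];
    for (size_t t = 0; t < k; ++t) {
      grad[t] -= f[k] * a[k] * cum / (1.f - a[t]);
    }
    cum *= (1.f - a[k]);
  }

  // Compare against central finite differences.
  const float eps = 1e-3f;
  for (size_t t = 0; t < a.size(); ++t) {
    std::vector<float> ap = a, am = a;
    ap[t] += eps;
    am[t] -= eps;
    float numeric = (composite(f, ap) - composite(f, am)) / (2.f * eps);
    std::printf("t=%zu analytic=%.5f numeric=%.5f\n", t, grad[t], numeric);
  }
  return 0;
}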
Lines changed: 110 additions & 0 deletions
@@ -0,0 +1,110 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

#include <torch/extension.h>
#include "pytorch3d_cutils.h"

#include <vector>

// Perform alpha compositing of points in a z-buffer.
//
// Inputs:
//     features: FloatTensor of shape (C, P) which gives the features
//             of each point where C is the size of the feature and
//             P the number of points.
//     alphas: FloatTensor of shape (N, points_per_pixel, W, W) where
//             points_per_pixel is the number of points in the z-buffer
//             sorted in z-order, and W is the image size.
//     points_idx: IntTensor of shape (N, points_per_pixel, W, W) giving the
//             indices of the nearest points at each pixel, sorted in z-order.
// Returns:
//     weighted_fs: FloatTensor of shape (N, C, W, W) giving the accumulated
//             feature for each point. Concretely, it gives:
//             weighted_fs[b,c,i,j] = sum_k cum_alpha_k *
//                 features[c,points_idx[b,k,i,j]]
//             where cum_alpha_k =
//                 alphas[b,k,i,j] * prod_l=0..k-1 (1 - alphas[b,l,i,j])

// CUDA declarations
#ifdef WITH_CUDA
torch::Tensor alphaCompositeCudaForward(
    const torch::Tensor& features,
    const torch::Tensor& alphas,
    const torch::Tensor& points_idx);

std::tuple<torch::Tensor, torch::Tensor> alphaCompositeCudaBackward(
    const torch::Tensor& grad_outputs,
    const torch::Tensor& features,
    const torch::Tensor& alphas,
    const torch::Tensor& points_idx);
#endif

// C++ declarations
torch::Tensor alphaCompositeCpuForward(
    const torch::Tensor& features,
    const torch::Tensor& alphas,
    const torch::Tensor& points_idx);

std::tuple<torch::Tensor, torch::Tensor> alphaCompositeCpuBackward(
    const torch::Tensor& grad_outputs,
    const torch::Tensor& features,
    const torch::Tensor& alphas,
    const torch::Tensor& points_idx);

torch::Tensor alphaCompositeForward(
    torch::Tensor& features,
    torch::Tensor& alphas,
    torch::Tensor& points_idx) {
  features = features.contiguous();
  alphas = alphas.contiguous();
  points_idx = points_idx.contiguous();

  if (features.type().is_cuda()) {
#ifdef WITH_CUDA
    CHECK_CONTIGUOUS_CUDA(features);
    CHECK_CONTIGUOUS_CUDA(alphas);
    CHECK_CONTIGUOUS_CUDA(points_idx);
#else
    AT_ERROR("Not compiled with GPU support");
#endif
    return alphaCompositeCudaForward(features, alphas, points_idx);
  } else {
    CHECK_CONTIGUOUS(features);
    CHECK_CONTIGUOUS(alphas);
    CHECK_CONTIGUOUS(points_idx);

    return alphaCompositeCpuForward(features, alphas, points_idx);
  }
}

std::tuple<torch::Tensor, torch::Tensor> alphaCompositeBackward(
    torch::Tensor& grad_outputs,
    torch::Tensor& features,
    torch::Tensor& alphas,
    torch::Tensor& points_idx) {
  grad_outputs = grad_outputs.contiguous();
  features = features.contiguous();
  alphas = alphas.contiguous();
  points_idx = points_idx.contiguous();

  if (grad_outputs.type().is_cuda()) {
#ifdef WITH_CUDA
    CHECK_CONTIGUOUS_CUDA(grad_outputs);
    CHECK_CONTIGUOUS_CUDA(features);
    CHECK_CONTIGUOUS_CUDA(alphas);
    CHECK_CONTIGUOUS_CUDA(points_idx);
#else
    AT_ERROR("Not compiled with GPU support");
#endif

    return alphaCompositeCudaBackward(
        grad_outputs, features, alphas, points_idx);
  } else {
    CHECK_CONTIGUOUS(grad_outputs);
    CHECK_CONTIGUOUS(features);
    CHECK_CONTIGUOUS(alphas);
    CHECK_CONTIGUOUS(points_idx);

    return alphaCompositeCpuBackward(
        grad_outputs, features, alphas, points_idx);
  }
}
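The dispatch functions above are what a Python-facing extension module would register; the registration file itself is not among the lines shown here. A hypothetical pybind11 sketch, with assumed op names rather than the ones actually used by the extension:

// Hypothetical binding sketch (module and op names are assumptions).
// torch/extension.h brings in pybind11 and the torch::Tensor caster.
#include <torch/extension.h>
#include <tuple>

// Dispatch entry points defined in the header shown above.
torch::Tensor alphaCompositeForward(
    torch::Tensor& features,
    torch::Tensor& alphas,
    torch::Tensor& points_idx);
std::tuple<torch::Tensor, torch::Tensor> alphaCompositeBackward(
    torch::Tensor& grad_outputs,
    torch::Tensor& features,
    torch::Tensor& alphas,
    torch::Tensor& points_idx);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("alpha_composite_forward", &alphaCompositeForward);
  m.def("alpha_composite_backward", &alphaCompositeBackward);
}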
