|
| 1 | +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. |
| 2 | + |
| 3 | +#include <torch/extension.h> |
| 4 | + |
| 5 | +#include <cuda.h> |
| 6 | +#include <cuda_runtime.h> |
| 7 | + |
| 8 | +#include <stdio.h> |
| 9 | +#include <vector> |
| 10 | + |
| 11 | +// TODO(gkioxari) support all data types once AtomicAdd supports doubles. |
| 12 | +// Currently, support is for floats only. |
| 13 | +__global__ void alphaCompositeCudaForwardKernel( |
| 14 | + // clang-format off |
| 15 | + torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> result, |
| 16 | + const torch::PackedTensorAccessor<float, 2, torch::RestrictPtrTraits, size_t> features, |
| 17 | + const torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> alphas, |
| 18 | + const torch::PackedTensorAccessor<int64_t, 4, torch::RestrictPtrTraits, size_t> points_idx) { |
| 19 | + // clang-format on |
| 20 | + const int64_t batch_size = result.size(0); |
| 21 | + const int64_t C = features.size(0); |
| 22 | + const int64_t H = points_idx.size(2); |
| 23 | + const int64_t W = points_idx.size(3); |
| 24 | + |
| 25 | + // Get the batch and index |
| 26 | + const int batch = blockIdx.x; |
| 27 | + |
| 28 | + const int num_pixels = C * W * H; |
| 29 | + const int num_threads = gridDim.y * blockDim.x; |
| 30 | + const int tid = blockIdx.y * blockDim.x + threadIdx.x; |
| 31 | + |
| 32 | + // Iterate over each feature in each pixel |
| 33 | + for (int pid = tid; pid < num_pixels; pid += num_threads) { |
| 34 | + int ch = pid / (W * H); |
| 35 | + int j = (pid % (W * H)) / H; |
| 36 | + int i = (pid % (W * H)) % H; |
| 37 | + |
| 38 | + // alphacomposite the different values |
| 39 | + float cum_alpha = 1.; |
| 40 | + // Iterate through the closest K points for this pixel |
| 41 | + for (int k = 0; k < points_idx.size(1); ++k) { |
| 42 | + int n_idx = points_idx[batch][k][j][i]; |
| 43 | + |
| 44 | + // Sentinel value is -1 indicating no point overlaps the pixel |
| 45 | + if (n_idx < 0) { |
| 46 | + continue; |
| 47 | + } |
| 48 | + |
| 49 | + float alpha = alphas[batch][k][j][i]; |
| 50 | + // TODO(gkioxari) It might be more efficient to have threads write in a |
| 51 | + // local variable, and move atomicAdd outside of the loop such that |
| 52 | + // atomicAdd is executed once per thread. |
| 53 | + atomicAdd( |
| 54 | + &result[batch][ch][j][i], features[ch][n_idx] * cum_alpha * alpha); |
| 55 | + cum_alpha = cum_alpha * (1 - alpha); |
| 56 | + } |
| 57 | + } |
| 58 | +} |
| 59 | + |
| 60 | +// TODO(gkioxari) support all data types once AtomicAdd supports doubles. |
| 61 | +// Currently, support is for floats only. |
| 62 | +__global__ void alphaCompositeCudaBackwardKernel( |
| 63 | + // clang-format off |
| 64 | + torch::PackedTensorAccessor<float, 2, torch::RestrictPtrTraits, size_t> grad_features, |
| 65 | + torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> grad_alphas, |
| 66 | + const torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> grad_outputs, |
| 67 | + const torch::PackedTensorAccessor<float, 2, torch::RestrictPtrTraits, size_t> features, |
| 68 | + const torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> alphas, |
| 69 | + const torch::PackedTensorAccessor<int64_t, 4, torch::RestrictPtrTraits, size_t> points_idx) { |
| 70 | + // clang-format on |
| 71 | + const int64_t batch_size = points_idx.size(0); |
| 72 | + const int64_t C = features.size(0); |
| 73 | + const int64_t H = points_idx.size(2); |
| 74 | + const int64_t W = points_idx.size(3); |
| 75 | + |
| 76 | + // Get the batch and index |
| 77 | + const int batch = blockIdx.x; |
| 78 | + |
| 79 | + const int num_pixels = C * W * H; |
| 80 | + const int num_threads = gridDim.y * blockDim.x; |
| 81 | + const int tid = blockIdx.y * blockDim.x + threadIdx.x; |
| 82 | + |
| 83 | + // Parallelize over each feature in each pixel in images of size H * W, |
| 84 | + // for each image in the batch of size batch_size |
| 85 | + for (int pid = tid; pid < num_pixels; pid += num_threads) { |
| 86 | + int ch = pid / (W * H); |
| 87 | + int j = (pid % (W * H)) / H; |
| 88 | + int i = (pid % (W * H)) % H; |
| 89 | + |
| 90 | + // alphacomposite the different values |
| 91 | + float cum_alpha = 1.; |
| 92 | + // Iterate through the closest K points for this pixel |
| 93 | + for (int k = 0; k < points_idx.size(1); ++k) { |
| 94 | + int n_idx = points_idx[batch][k][j][i]; |
| 95 | + |
| 96 | + // Sentinel value is -1 indicating no point overlaps the pixel |
| 97 | + if (n_idx < 0) { |
| 98 | + continue; |
| 99 | + } |
| 100 | + float alpha = alphas[batch][k][j][i]; |
| 101 | + |
| 102 | + // TODO(gkioxari) It might be more efficient to have threads write in a |
| 103 | + // local variable, and move atomicAdd outside of the loop such that |
| 104 | + // atomicAdd is executed once per thread. |
| 105 | + atomicAdd( |
| 106 | + &grad_alphas[batch][k][j][i], |
| 107 | + cum_alpha * features[ch][n_idx] * grad_outputs[batch][ch][j][i]); |
| 108 | + atomicAdd( |
| 109 | + &grad_features[ch][n_idx], |
| 110 | + cum_alpha * alpha * grad_outputs[batch][ch][j][i]); |
| 111 | + |
| 112 | + // Iterate over all (K-1) nearest points to update gradient |
| 113 | + for (int t = 0; t < k; ++t) { |
| 114 | + int t_idx = points_idx[batch][t][j][i]; |
| 115 | + // Sentinel value is -1, indicating no point overlaps this pixel |
| 116 | + if (t_idx < 0) { |
| 117 | + continue; |
| 118 | + } |
| 119 | + float alpha_tvalue = alphas[batch][t][j][i]; |
| 120 | + // TODO(gkioxari) It might be more efficient to have threads write in a |
| 121 | + // local variable, and move atomicAdd outside of the loop such that |
| 122 | + // atomicAdd is executed once per thread. |
| 123 | + atomicAdd( |
| 124 | + &grad_alphas[batch][t][j][i], |
| 125 | + -grad_outputs[batch][ch][j][i] * features[ch][n_idx] * cum_alpha * |
| 126 | + alpha / (1 - alpha_tvalue)); |
| 127 | + } |
| 128 | + |
| 129 | + cum_alpha = cum_alpha * (1 - alphas[batch][k][j][i]); |
| 130 | + } |
| 131 | + } |
| 132 | +} |
| 133 | + |
| 134 | +torch::Tensor alphaCompositeCudaForward( |
| 135 | + const torch::Tensor& features, |
| 136 | + const torch::Tensor& alphas, |
| 137 | + const torch::Tensor& points_idx) { |
| 138 | + const int64_t batch_size = points_idx.size(0); |
| 139 | + const int64_t C = features.size(0); |
| 140 | + const int64_t H = points_idx.size(2); |
| 141 | + const int64_t W = points_idx.size(3); |
| 142 | + |
| 143 | + auto result = torch::zeros({batch_size, C, H, W}, features.options()); |
| 144 | + |
| 145 | + const dim3 threadsPerBlock(64); |
| 146 | + const dim3 numBlocks(batch_size, 1024 / batch_size + 1); |
| 147 | + |
| 148 | + // TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports |
| 149 | + // doubles. Currently, support is for floats only. |
| 150 | + alphaCompositeCudaForwardKernel<<<numBlocks, threadsPerBlock>>>( |
| 151 | + // clang-format off |
| 152 | + result.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(), |
| 153 | + features.packed_accessor<float, 2, torch::RestrictPtrTraits, size_t>(), |
| 154 | + alphas.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(), |
| 155 | + points_idx.packed_accessor<int64_t, 4, torch::RestrictPtrTraits, size_t>()); |
| 156 | + // clang-format on |
| 157 | + |
| 158 | + return result; |
| 159 | +} |
| 160 | + |
| 161 | +std::tuple<torch::Tensor, torch::Tensor> alphaCompositeCudaBackward( |
| 162 | + const torch::Tensor& grad_outputs, |
| 163 | + const torch::Tensor& features, |
| 164 | + const torch::Tensor& alphas, |
| 165 | + const torch::Tensor& points_idx) { |
| 166 | + auto grad_features = torch::zeros_like(features); |
| 167 | + auto grad_alphas = torch::zeros_like(alphas); |
| 168 | + |
| 169 | + const int64_t bs = alphas.size(0); |
| 170 | + |
| 171 | + const dim3 threadsPerBlock(64); |
| 172 | + const dim3 numBlocks(bs, 1024 / bs + 1); |
| 173 | + |
| 174 | + // TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports |
| 175 | + // doubles. Currently, support is for floats only. |
| 176 | + alphaCompositeCudaBackwardKernel<<<numBlocks, threadsPerBlock>>>( |
| 177 | + // clang-format off |
| 178 | + grad_features.packed_accessor<float, 2, torch::RestrictPtrTraits, size_t>(), |
| 179 | + grad_alphas.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(), |
| 180 | + grad_outputs.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(), |
| 181 | + features.packed_accessor<float, 2, torch::RestrictPtrTraits, size_t>(), |
| 182 | + alphas.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(), |
| 183 | + points_idx.packed_accessor<int64_t, 4, torch::RestrictPtrTraits, size_t>()); |
| 184 | + // clang-format on |
| 185 | + |
| 186 | + return std::make_tuple(grad_features, grad_alphas); |
| 187 | +} |
0 commit comments