Skip to content

Commit 26d2cc2

Browse files
jcjohnson authored and facebook-github-bot committed
CUDA kernel for interpolate_face_attributes
Summary: When rendering meshes with Phong shading, interpolate_face_attributes was taking up a nontrivial fraction of the overall shading time. This diff replaces our Python implementation of this function with a CUDA implementation. Reviewed By: nikhilaravi Differential Revision: D21610763 fbshipit-source-id: 2bb362a28f698541812aeab539047264b125ebb8
1 parent 0505e5f commit 26d2cc2

File tree

11 files changed

+630
-140
lines changed

11 files changed

+630
-140
lines changed

pytorch3d/csrc/ext.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "compositing/weighted_sum.h"
77
#include "face_areas_normals/face_areas_normals.h"
88
#include "gather_scatter/gather_scatter.h"
9+
#include "interp_face_attrs/interp_face_attrs.h"
910
#include "knn/knn.h"
1011
#include "packed_to_padded_tensor/packed_to_padded_tensor.h"
1112
#include "point_mesh/point_mesh_edge.h"
@@ -18,6 +19,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
1819
m.def("face_areas_normals_backward", &FaceAreasNormalsBackward);
1920
m.def("packed_to_padded", &PackedToPadded);
2021
m.def("padded_to_packed", &PaddedToPacked);
22+
m.def("interp_face_attrs_forward", &InterpFaceAttrsForward);
23+
m.def("interp_face_attrs_backward", &InterpFaceAttrsBackward);
2124
#ifdef WITH_CUDA
2225
m.def("knn_check_version", &KnnCheckVersion);
2326
#endif
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2+
3+
#include <ATen/ATen.h>
4+
#include <ATen/cuda/CUDAContext.h>
5+
#include <c10/cuda/CUDAGuard.h>
6+
#include <tuple>
7+
8+
// Forward kernel: interpolates per-vertex face attributes at every pixel.
//
// Each (pixel, channel) pair of the (P, D) output is one work item; work items
// are distributed over all launched threads with a grid-stride loop, so any
// 1-D launch configuration is valid.
//
// Pixels whose face index is negative are skipped, leaving whatever value the
// caller stored in pix_attrs (the host wrapper zero-initializes it).
//
// Preconditions: all buffers are densely packed in the layouts noted below and
// every non-negative face index is a valid row of face_attrs. F is accepted
// for interface symmetry but is not consulted here.
template <typename scalar_t>
__global__ void InterpFaceAttrsForwardKernel(
    const int64_t* __restrict__ pix_to_face, // (P,)
    const scalar_t* __restrict__ barycentric_coords, // (P, 3)
    const scalar_t* __restrict__ face_attrs, // (F, 3, D)
    scalar_t* pix_attrs, // (P, D)
    const size_t P,
    const size_t F,
    const size_t D) {
  // All index math is done in size_t: P * D may exceed INT_MAX, and a signed
  // int loop counter would overflow before reaching the bound.
  const size_t tid =
      threadIdx.x + static_cast<size_t>(blockIdx.x) * blockDim.x;
  const size_t num_threads = static_cast<size_t>(blockDim.x) * gridDim.x;
  const size_t total = P * D;

  for (size_t pd = tid; pd < total; pd += num_threads) {
    const size_t p = pd / D; // pixel index
    const size_t d = pd % D; // channel index
    const int64_t f = pix_to_face[p];
    if (f < 0) {
      continue; // pixel not covered by any face
    }
    // Offset of this face's (3, D) attribute slab.
    const size_t face_base = static_cast<size_t>(f) * 3 * D;
    scalar_t pix_attr = scalar_t(0);
    for (int i = 0; i < 3; ++i) {
      const scalar_t weight = barycentric_coords[p * 3 + i];
      const scalar_t vert_attr = face_attrs[face_base + i * D + d];
      pix_attr += weight * vert_attr;
    }
    pix_attrs[pd] = pix_attr; // pd == p * D + d
  }
}
35+
36+
// Host wrapper for the forward pass on CUDA tensors.
//
// Args:
//   pix_to_face: int64 tensor of shape (P,); negative entries mark pixels that
//       are not covered by any face.
//   barycentric_coords: floating-point tensor of shape (P, 3).
//   face_attrs: floating-point tensor of shape (F, 3, D), same dtype as
//       barycentric_coords.
//
// Returns:
//   Tensor of shape (P, D) with face_attrs' dtype/device. Rows belonging to
//   uncovered pixels stay zero.
//
// Raises (via TORCH_CHECK / ATen checks) when the tensors live on different
// GPUs, the floating-point dtypes differ, or the shapes do not match.
at::Tensor InterpFaceAttrsForwardCuda(
    const at::Tensor& pix_to_face,
    const at::Tensor& barycentric_coords,
    const at::Tensor& face_attrs) {
  // Make sure all inputs are on the same device
  at::TensorArg pix_to_face_t{pix_to_face, "pix_to_face", 1},
      barycentric_coords_t{barycentric_coords, "barycentric_coords", 2},
      face_attrs_t{face_attrs, "face_attributes", 3};
  at::CheckedFrom c = "InterpFaceAttrsForwardCuda";
  at::checkAllSameGPU(c, {pix_to_face_t, barycentric_coords_t, face_attrs_t});
  at::checkAllSameType(c, {barycentric_coords_t, face_attrs_t});

  // Set the device for the kernel launch based on the input
  at::cuda::CUDAGuard device_guard(pix_to_face.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // Validate ranks before reading sizes: the kernel indexes the raw buffers
  // assuming exactly these layouts, so an extra dimension would otherwise be
  // read silently with the wrong strides.
  TORCH_CHECK(pix_to_face.dim() == 1, "pix_to_face must have size (P,)");
  TORCH_CHECK(
      face_attrs.dim() == 3 && face_attrs.size(1) == 3,
      "face_attrs must have size (F, 3, D)");

  const auto P = pix_to_face.size(0);
  const auto F = face_attrs.size(0);
  const auto D = face_attrs.size(2);

  TORCH_CHECK(
      barycentric_coords.dim() == 2 && barycentric_coords.size(0) == P &&
          barycentric_coords.size(1) == 3,
      "barycentric_coords must have size (P, 3)");

  // at::zeros returns a contiguous tensor, so its data pointer can be handed
  // to the kernel directly.
  auto pix_attrs = at::zeros({P, D}, face_attrs.options());
  if (pix_attrs.numel() == 0) {
    // Nothing to interpolate; skip the launch.
    return pix_attrs;
  }

  // Keep the contiguous views alive in named locals for the duration of the
  // (asynchronous) launch instead of relying on expression temporaries.
  const at::Tensor pix_to_face_c = pix_to_face.contiguous();
  const at::Tensor barycentric_coords_c = barycentric_coords.contiguous();
  const at::Tensor face_attrs_c = face_attrs.contiguous();

  const int threads = 1024;
  const int blocks = 512;
  AT_DISPATCH_FLOATING_TYPES(
      face_attrs.scalar_type(), "interp_face_attrs_cuda", ([&] {
        InterpFaceAttrsForwardKernel<<<blocks, threads, 0, stream>>>(
            pix_to_face_c.data_ptr<int64_t>(),
            barycentric_coords_c.data_ptr<scalar_t>(),
            face_attrs_c.data_ptr<scalar_t>(),
            pix_attrs.data_ptr<scalar_t>(),
            P,
            F,
            D);
      }));
  AT_CUDA_CHECK(cudaGetLastError());
  return pix_attrs;
}
78+
79+
// Backward kernel: accumulates gradients of the interpolation with respect to
// the barycentric coordinates and the per-vertex face attributes.
//
// Each (pixel, channel) pair of the (P, D) upstream gradient is one work item,
// distributed over all launched threads with a grid-stride loop. Pixels with a
// negative face index contribute nothing.
//
// Both outputs are accumulated with atomicAdd because several work items write
// the same location: all D channels of a pixel add into the same
// grad_barycentric_coords entry, and every pixel covered by the same face adds
// into the same grad_face_attrs entry. The caller must zero both buffers
// before the launch.
//
// Preconditions: all buffers are densely packed in the layouts noted below and
// every non-negative face index is a valid row of face_attrs. F is accepted
// for interface symmetry but is not consulted here. scalar_t must be a type
// for which atomicAdd is available on the target architecture.
template <typename scalar_t>
__global__ void InterpFaceAttrsBackwardKernel(
    const int64_t* __restrict__ pix_to_face, // (P,)
    const scalar_t* __restrict__ barycentric_coords, // (P, 3)
    const scalar_t* __restrict__ face_attrs, // (F, 3, D)
    const scalar_t* __restrict__ grad_pix_attrs, // (P, D)
    scalar_t* __restrict__ grad_barycentric_coords, // (P, 3)
    scalar_t* __restrict__ grad_face_attrs, // (F, 3, D)
    const size_t P,
    const size_t F,
    const size_t D) {
  // All index math is done in size_t: P * D may exceed INT_MAX, and a signed
  // int loop counter would overflow before reaching the bound.
  const size_t tid =
      threadIdx.x + static_cast<size_t>(blockIdx.x) * blockDim.x;
  const size_t num_threads = static_cast<size_t>(blockDim.x) * gridDim.x;
  const size_t total = P * D;

  for (size_t pd = tid; pd < total; pd += num_threads) {
    const size_t p = pd / D; // pixel index
    const size_t d = pd % D; // channel index
    const int64_t f = pix_to_face[p];
    if (f < 0) {
      continue; // pixel not covered by any face
    }
    const scalar_t upstream_grad = grad_pix_attrs[pd]; // pd == p * D + d
    // Offset of this face's (3, D) attribute slab.
    const size_t face_base = static_cast<size_t>(f) * 3 * D;
    for (int i = 0; i < 3; ++i) {
      const size_t bary_idx = p * 3 + i;
      const size_t attr_idx = face_base + i * D + d;
      const scalar_t weight = barycentric_coords[bary_idx];
      const scalar_t vert_attr = face_attrs[attr_idx];
      // d(out)/d(weight) = vert_attr ; d(out)/d(vert_attr) = weight
      atomicAdd(grad_barycentric_coords + bary_idx, vert_attr * upstream_grad);
      atomicAdd(grad_face_attrs + attr_idx, weight * upstream_grad);
    }
  }
}
110+
111+
// Host wrapper for the backward pass on CUDA tensors.
//
// Args:
//   pix_to_face: int64 tensor of shape (P,); negative entries mark pixels that
//       are not covered by any face.
//   barycentric_coords: float32 tensor of shape (P, 3).
//   face_attrs: float32 tensor of shape (F, 3, D).
//   grad_pix_attrs: float32 upstream gradient of shape (P, D).
//
// Returns a tuple (grad_barycentric_coords, grad_face_attrs) with the same
// shapes as barycentric_coords and face_attrs.
//
// Only float32 is supported: the kernel is launched with float pointers, and
// data_ptr<float>() raises for any other dtype.
std::tuple<at::Tensor, at::Tensor> InterpFaceAttrsBackwardCuda(
    const at::Tensor& pix_to_face,
    const at::Tensor& barycentric_coords,
    const at::Tensor& face_attrs,
    const at::Tensor& grad_pix_attrs) {
  // Make sure all inputs are on the same device
  at::TensorArg pix_to_face_t{pix_to_face, "pix_to_face", 1},
      barycentric_coords_t{barycentric_coords, "barycentric_coords", 2},
      face_attrs_t{face_attrs, "face_attributes", 3},
      grad_pix_attrs_t{grad_pix_attrs, "grad_pix_attrs", 4};
  at::CheckedFrom c = "InterpFaceAttrsBackwardCuda";
  at::checkAllSameGPU(
      c, {pix_to_face_t, barycentric_coords_t, face_attrs_t, grad_pix_attrs_t});
  at::checkAllSameType(
      c, {barycentric_coords_t, face_attrs_t, grad_pix_attrs_t});

  // Set the device for the kernel launch based on the input
  at::cuda::CUDAGuard device_guard(pix_to_face.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  const auto P = pix_to_face.size(0);
  const auto F = face_attrs.size(0);
  const auto D = face_attrs.size(2);

  TORCH_CHECK(
      barycentric_coords.size(0) == P && barycentric_coords.size(1) == 3,
      "barycentric_coords must have size (P, 3)");
  TORCH_CHECK(face_attrs.size(1) == 3, "face_attrs must have size (F, 3, D)");
  TORCH_CHECK(
      grad_pix_attrs.size(0) == P && grad_pix_attrs.size(1) == D,
      "grad_pix_attrs must have size (P, D)");

  // Allocate the gradient buffers with at::zeros (always contiguous) rather
  // than at::zeros_like. zeros_like can preserve the strides of a dense but
  // non-contiguous input; calling .contiguous() on such a buffer would hand
  // the kernel a temporary copy and the accumulated gradients would be lost.
  auto grad_barycentric_coords =
      at::zeros(barycentric_coords.sizes(), barycentric_coords.options());
  auto grad_face_attrs = at::zeros(face_attrs.sizes(), face_attrs.options());

  if (grad_pix_attrs.numel() == 0) {
    // No (pixel, channel) work items: both gradients are identically zero.
    return std::make_tuple(grad_barycentric_coords, grad_face_attrs);
  }

  // Keep the contiguous views alive in named locals for the duration of the
  // (asynchronous) launch instead of relying on expression temporaries.
  const at::Tensor pix_to_face_c = pix_to_face.contiguous();
  const at::Tensor barycentric_coords_c = barycentric_coords.contiguous();
  const at::Tensor face_attrs_c = face_attrs.contiguous();
  const at::Tensor grad_pix_attrs_c = grad_pix_attrs.contiguous();

  const int threads = 1024;
  const int blocks = 512;
  // Only allow float for now.
  // TODO: Add support for double once we fix atomicAdd
  InterpFaceAttrsBackwardKernel<<<blocks, threads, 0, stream>>>(
      pix_to_face_c.data_ptr<int64_t>(),
      barycentric_coords_c.data_ptr<float>(),
      face_attrs_c.data_ptr<float>(),
      grad_pix_attrs_c.data_ptr<float>(),
      grad_barycentric_coords.data_ptr<float>(),
      grad_face_attrs.data_ptr<float>(),
      P,
      F,
      D);
  AT_CUDA_CHECK(cudaGetLastError());
  return std::make_tuple(grad_barycentric_coords, grad_face_attrs);
}
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2+
3+
#pragma once
4+
#include <torch/extension.h>
5+
#include <tuple>
6+
#include "utils/pytorch3d_cutils.h"
7+
8+
// Interpolates per-face attributes (forward pass)
9+
//
10+
// Inputs:
11+
// pix_to_face: LongTensor of shape (P,) giving a face index for each pixel.
12+
// Each element should be < F, the total number of faces.
13+
// Face indices < 0 indicate that the pixel is not covered by a face.
14+
// barycentric_coords: FloatTensor of shape (P, 3) giving barycentric coords.
15+
// face_attrs: FloatTensor of shape (F, 3, D) giving a D-dimensional
16+
// value for each vertex of each face.
17+
//
18+
// Returns:
19+
// pix_attributes: FloatTensor of shape (P, D) giving an interpolated value
20+
// for each pixel.
21+
22+
// CPU implementation
23+
// Placeholder for the CPU forward pass: always raises "Not Implemented".
//
// Declared `inline` because this definition lives in a header; without it,
// including the header from more than one translation unit produces
// duplicate-symbol link errors.
inline at::Tensor InterpFaceAttrsForwardCpu(
    const at::Tensor& pix_to_face,
    const at::Tensor& barycentric_coords,
    const at::Tensor& face_attrs) {
  AT_ERROR("Not Implemented");
  // Unreachable; present only so every control path returns a value.
  return pix_to_face;
}
30+
31+
#ifdef WITH_CUDA
32+
// Cuda implementation.
33+
at::Tensor InterpFaceAttrsForwardCuda(
34+
const at::Tensor& pix_to_face,
35+
const at::Tensor& barycentric_coords,
36+
const at::Tensor& face_attrs);
37+
#endif
38+
39+
// General implementation
40+
// Device dispatcher for the forward pass.
//
// Routes to the CUDA implementation when pix_to_face is a CUDA tensor (after
// applying CHECK_CUDA to the other two inputs) and to the CPU implementation
// otherwise. In builds without WITH_CUDA, a CUDA input raises
// "Not compiled with GPU support.".
//
// Declared `inline` because this definition lives in a header; without it,
// including the header from more than one translation unit produces
// duplicate-symbol link errors.
inline at::Tensor InterpFaceAttrsForward(
    const at::Tensor& pix_to_face,
    const at::Tensor& barycentric_coords,
    const at::Tensor& face_attrs) {
  if (pix_to_face.is_cuda()) {
#ifdef WITH_CUDA
    // CHECK_CUDA is provided by utils/pytorch3d_cutils.h.
    CHECK_CUDA(face_attrs);
    CHECK_CUDA(barycentric_coords);
    return InterpFaceAttrsForwardCuda(
        pix_to_face, barycentric_coords, face_attrs);
#else
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
  return InterpFaceAttrsForwardCpu(pix_to_face, barycentric_coords, face_attrs);
}
56+
57+
// Interpolates per-face attributes (backward pass)
58+
//
59+
// Inputs:
60+
// pix_to_face: LongTensor of shape (P,) giving a face index for each pixel.
61+
// Each element should be < F, the total number of faces.
62+
// Face indices < 0 indicate that the pixel is not covered by a face.
63+
// barycentric_coords: FloatTensor of shape (P, 3) giving barycentric coords.
64+
// face_attrs: FloatTensor of shape (F, 3, D) giving a D-dimensional
65+
// value for each vertex of each face.
66+
// grad_pix_attrs: Upstream gradients of shape (P, D)
67+
//
68+
// Returns a tuple of:
69+
// grad_barycentric_coords: FloatTensor of shape (P, 3)
70+
// grad_face_attrs: FloatTensor of shape (F, 3, D)
71+
72+
std::tuple<at::Tensor, at::Tensor> InterpFaceAttrsBackwardCpu(
73+
const at::Tensor& pix_to_face,
74+
const at::Tensor& barycentric_coords,
75+
const at::Tensor& face_attrs,
76+
const at::Tensor& grad_pix_attrs) {
77+
AT_ERROR("Not Implemented");
78+
return std::make_tuple(pix_to_face, pix_to_face);
79+
}
80+
81+
std::tuple<at::Tensor, at::Tensor> InterpFaceAttrsBackwardCuda(
82+
const at::Tensor& pix_to_face,
83+
const at::Tensor& barycentric_coords,
84+
const at::Tensor& face_attrs,
85+
const at::Tensor& grad_pix_attrs);
86+
87+
std::tuple<at::Tensor, at::Tensor> InterpFaceAttrsBackward(
88+
const at::Tensor& pix_to_face,
89+
const at::Tensor& barycentric_coords,
90+
const at::Tensor& face_attrs,
91+
const at::Tensor& grad_pix_attrs) {
92+
if (pix_to_face.is_cuda()) {
93+
#ifdef WITH_CUDA
94+
CHECK_CUDA(face_attrs);
95+
CHECK_CUDA(barycentric_coords);
96+
CHECK_CUDA(grad_pix_attrs);
97+
return InterpFaceAttrsBackwardCuda(
98+
pix_to_face, barycentric_coords, face_attrs, grad_pix_attrs);
99+
#else
100+
AT_ERROR("Not compiled with GPU support.");
101+
#endif
102+
}
103+
return InterpFaceAttrsBackwardCpu(
104+
pix_to_face, barycentric_coords, face_attrs, grad_pix_attrs);
105+
}

pytorch3d/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from .cubify import cubify
55
from .graph_conv import GraphConv
6+
from .interp_face_attrs import interpolate_face_attributes
67
from .knn import knn_gather, knn_points
78
from .mesh_face_areas_normals import mesh_face_areas_normals
89
from .packed_to_padded import packed_to_padded, padded_to_packed

pytorch3d/ops/interp_face_attrs.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2+
3+
import torch
4+
from pytorch3d import _C
5+
from torch.autograd import Function
6+
from torch.autograd.function import once_differentiable
7+
8+
9+
def interpolate_face_attributes(
    pix_to_face: torch.Tensor,
    barycentric_coords: torch.Tensor,
    face_attributes: torch.Tensor,
) -> torch.Tensor:
    """
    Interpolate arbitrary face attributes using the barycentric coordinates
    for each pixel in the rasterized output.

    Args:
        pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices
            of the faces (in the packed representation) which overlap each
            pixel in the image. A value < 0 indicates that the pixel does not
            overlap any face and should be skipped.
        barycentric_coords: FloatTensor of shape (N, H, W, K, 3) specifying
            the barycentric coordinates of each pixel
            relative to the faces (in the packed
            representation) which overlap the pixel.
        face_attributes: packed attributes of shape (total_faces, 3, D),
            specifying the value of the attribute for each
            vertex in the face.

    Returns:
        pixel_vals: tensor of shape (N, H, W, K, D) giving the interpolated
        value of the face attribute for each pixel.

    Raises:
        ValueError: if face_attributes does not have exactly three vertices
            per face, or if pix_to_face does not have shape (N, H, W, K).
    """
    # Check shapes
    _, FV, D = face_attributes.shape
    if FV != 3:
        raise ValueError("Faces can only have three vertices; got %r" % FV)
    N, H, W, K, _ = barycentric_coords.shape
    if pix_to_face.shape != (N, H, W, K):
        msg = "pix_to_face must have shape (batch_size, H, W, K); got %r"
        raise ValueError(msg % (pix_to_face.shape,))

    # On CPU use the python version
    # TODO: Implement a C++ version of this function
    if not pix_to_face.is_cuda:
        args = (pix_to_face, barycentric_coords, face_attributes)
        return interpolate_face_attributes_python(*args)

    # Otherwise flatten and call the custom autograd function.
    # reshape (rather than view) is used so that non-contiguous inputs, e.g.
    # permuted rasterizer outputs, are accepted instead of raising.
    num_pix = N * H * W * K
    pix_to_face = pix_to_face.reshape(num_pix)
    barycentric_coords = barycentric_coords.reshape(num_pix, 3)
    args = (pix_to_face, barycentric_coords, face_attributes)
    out = _InterpFaceAttrs.apply(*args)
    # Use the explicit D instead of -1: -1 cannot be inferred when the
    # output has zero elements.
    out = out.view(N, H, W, K, D)
    return out
58+
59+
60+
class _InterpFaceAttrs(Function):
    """
    Autograd wrapper around the compiled interp_face_attrs forward/backward
    ops. Inputs are expected already flattened by the caller; pix_to_face
    receives no gradient.
    """

    @staticmethod
    def forward(ctx, pix_to_face, barycentric_coords, face_attrs):
        # All three inputs are needed again by the backward op.
        ctx.save_for_backward(pix_to_face, barycentric_coords, face_attrs)
        return _C.interp_face_attrs_forward(
            pix_to_face, barycentric_coords, face_attrs
        )

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_pix_attrs):
        pix_to_face, barycentric_coords, face_attrs = ctx.saved_tensors
        grad_barycentric_coords, grad_face_attrs = _C.interp_face_attrs_backward(
            pix_to_face, barycentric_coords, face_attrs, grad_pix_attrs
        )
        # One entry per forward input; the integer index tensor gets None.
        return None, grad_barycentric_coords, grad_face_attrs
77+
78+
79+
def interpolate_face_attributes_python(
    pix_to_face: torch.Tensor,
    barycentric_coords: torch.Tensor,
    face_attributes: torch.Tensor,
) -> torch.Tensor:
    """
    Pure-PyTorch implementation of face attribute interpolation.

    Args:
        pix_to_face: integer tensor of shape (N, H, W, K) of face indices;
            entries < 0 mark pixels that are not covered by any face.
        barycentric_coords: tensor of shape (N, H, W, K, 3).
        face_attributes: tensor of shape (F, 3, D).

    Returns:
        Tensor of shape (N, H, W, K, D); entries for uncovered pixels are 0.
        The input pix_to_face is not modified.
    """
    _, _, D = face_attributes.shape
    N, H, W, K, _ = barycentric_coords.shape
    num_pix = N * H * W * K

    # Replace empty pixels in pix_to_face with 0 in order to interpolate.
    # masked_fill is out of place, so the caller's tensor is left untouched.
    mask = pix_to_face < 0
    safe_idx = pix_to_face.masked_fill(mask, 0)

    # reshape (not view): the index tensor may carry non-contiguous strides
    # inherited from the input, for which view would raise.
    idx = safe_idx.reshape(num_pix, 1, 1).expand(num_pix, 3, D)
    pixel_face_vals = face_attributes.gather(0, idx).view(N, H, W, K, 3, D)

    # Weighted sum over the three vertices of each face.
    pixel_vals = (barycentric_coords[..., None] * pixel_face_vals).sum(dim=-2)
    pixel_vals[mask] = 0  # Replace masked values in output.
    return pixel_vals

pytorch3d/renderer/mesh/shading.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
from typing import Tuple
55

66
import torch
7-
8-
from .texturing import interpolate_face_attributes
7+
from pytorch3d.ops import interpolate_face_attributes
98

109

1110
def _apply_lighting(

0 commit comments

Comments
 (0)