Skip to content

Commit ff19c64

Browse files
nikhilaravifacebook-github-bot
authored andcommitted
Barycentric clipping in the renderer and flat shading
Summary: Updates to the Renderer to enable barycentric clipping. This is important when there is blurring in the rasterization step. Also added support for flat shading. Reviewed By: jcjohnson Differential Revision: D19934259 fbshipit-source-id: 036e48636cd80d28a04405d7a29fcc71a2982904
1 parent f358b9b commit ff19c64

14 files changed

+254
-108
lines changed

pytorch3d/renderer/blending.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,9 @@ def sigmoid_alpha_blend(colors, fragments, blend_params) -> torch.Tensor:
9090
return torch.flip(pixel_colors, [1])
9191

9292

93-
def softmax_rgb_blend(colors, fragments, blend_params) -> torch.Tensor:
93+
def softmax_rgb_blend(
94+
colors, fragments, blend_params, znear: float = 1.0, zfar: float = 100
95+
) -> torch.Tensor:
9496
"""
9597
RGB and alpha channel blending to return an RGBA image based on the method
9698
proposed in [0]
@@ -118,13 +120,16 @@ def softmax_rgb_blend(colors, fragments, blend_params) -> torch.Tensor:
118120
exponential function used to control the opacity of the color.
119121
- background_color: (3) element list/tuple/torch.Tensor specifying
120122
the RGB values for the background color.
123+
znear: float, near clipping plane in the z direction
124+
zfar: float, far clipping plane in the z direction
121125
122126
Returns:
123127
RGBA pixel_colors: (N, H, W, 4)
124128
125129
[0] Shichen Liu et al, 'Soft Rasterizer: A Differentiable Renderer for
126130
Image-based 3D Reasoning'
127131
"""
132+
128133
N, H, W, K = fragments.pix_to_face.shape
129134
device = fragments.pix_to_face.device
130135
pix_colors = torch.ones(
@@ -140,11 +145,6 @@ def softmax_rgb_blend(colors, fragments, blend_params) -> torch.Tensor:
140145
delta = np.exp(1e-10 / blend_params.gamma) * 1e-10
141146
delta = torch.tensor(delta, device=device)
142147

143-
# Near and far clipping planes.
144-
# TODO: add zfar/znear as input params.
145-
zfar = 100.0
146-
znear = 1.0
147-
148148
# Mask for padded pixels.
149149
mask = fragments.pix_to_face >= 0
150150

@@ -164,6 +164,7 @@ def softmax_rgb_blend(colors, fragments, blend_params) -> torch.Tensor:
164164
# Weights for each face. Adjust the exponential by the max z to prevent
165165
# overflow. zbuf shape (N, H, W, K), find max over K.
166166
# TODO: there may still be some instability in the exponent calculation.
167+
167168
z_inv = (zfar - fragments.zbuf) / (zfar - znear) * mask
168169
z_inv_max = torch.max(z_inv, dim=-1).values[..., None]
169170
weights_num = prob_map * torch.exp((z_inv - z_inv_max) / blend_params.gamma)

pytorch3d/renderer/mesh/__init__.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
22

3+
4+
from .texturing import ( # isort:skip
5+
interpolate_texture_map,
6+
interpolate_vertex_colors,
7+
)
38
from .rasterize_meshes import rasterize_meshes
49
from .rasterizer import MeshRasterizer, RasterizationSettings
510
from .renderer import MeshRenderer
@@ -13,10 +18,6 @@
1318
TexturedSoftPhongShader,
1419
)
1520
from .shading import gouraud_shading, phong_shading
16-
from .texturing import ( # isort: skip
17-
interpolate_face_attributes,
18-
interpolate_texture_map,
19-
interpolate_vertex_colors,
20-
)
21+
from .utils import interpolate_face_attributes
2122

2223
__all__ = [k for k in globals().keys() if not k.startswith("_")]

pytorch3d/renderer/mesh/renderer.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
import torch
66
import torch.nn as nn
77

8+
from .rasterizer import Fragments
9+
from .utils import _clip_barycentric_coordinates, _interpolate_zbuf
10+
811
# A renderer class should be initialized with a
912
# function for rasterization and a function for shading.
1013
# The rasterizer should:
@@ -34,6 +37,34 @@ def __init__(self, rasterizer, shader):
3437
self.shader = shader
3538

3639
def forward(self, meshes_world, **kwargs) -> torch.Tensor:
40+
"""
41+
Render a batch of images from a batch of meshes by rasterizing and then shading.
42+
43+
NOTE: If the blur radius for rasterization is > 0.0, some pixels can have one or
44+
more barycentric coordinates lying outside the range [0, 1]. For a pixel with
45+
out of bounds barycentric coordinates with respect to a face f, clipping is required
46+
before interpolating the texture uv coordinates and z buffer so that the colors and
47+
depths are limited to the range for the corresponding face.
48+
"""
3749
fragments = self.rasterizer(meshes_world, **kwargs)
50+
raster_settings = kwargs.get(
51+
"raster_settings", self.rasterizer.raster_settings
52+
)
53+
if raster_settings.blur_radius > 0.0:
54+
# TODO: potentially move barycentric clipping to the rasterizer
55+
# if no downstream functions requires unclipped values.
56+
# This will avoid unnecssary re-interpolation of the z buffer.
57+
clipped_bary_coords = _clip_barycentric_coordinates(
58+
fragments.bary_coords
59+
)
60+
clipped_zbuf = _interpolate_zbuf(
61+
fragments.pix_to_face, clipped_bary_coords, meshes_world
62+
)
63+
fragments = Fragments(
64+
bary_coords=clipped_bary_coords,
65+
zbuf=clipped_zbuf,
66+
dists=fragments.dists,
67+
pix_to_face=fragments.pix_to_face,
68+
)
3869
images = self.shader(fragments, meshes_world, **kwargs)
3970
return images

pytorch3d/renderer/mesh/shader.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,7 @@ def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
270270
cameras = kwargs.get("cameras", self.cameras)
271271
lights = kwargs.get("lights", self.lights)
272272
materials = kwargs.get("materials", self.materials)
273+
blend_params = kwargs.get("blend_params", self.blend_params)
273274
colors = phong_shading(
274275
meshes=meshes,
275276
fragments=fragments,
@@ -278,7 +279,7 @@ def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
278279
cameras=cameras,
279280
materials=materials,
280281
)
281-
images = softmax_rgb_blend(colors, fragments, self.blend_params)
282+
images = softmax_rgb_blend(colors, fragments, blend_params)
282283
return images
283284

284285

pytorch3d/renderer/mesh/shading.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,12 @@ def phong_shading(
7070
vertex_normals = meshes.verts_normals_packed() # (V, 3)
7171
faces_verts = verts[faces]
7272
faces_normals = vertex_normals[faces]
73-
pixel_coords = interpolate_face_attributes(fragments, faces_verts)
74-
pixel_normals = interpolate_face_attributes(fragments, faces_normals)
73+
pixel_coords = interpolate_face_attributes(
74+
fragments.pix_to_face, fragments.bary_coords, faces_verts
75+
)
76+
pixel_normals = interpolate_face_attributes(
77+
fragments.pix_to_face, fragments.bary_coords, faces_normals
78+
)
7579
ambient, diffuse, specular = _apply_lighting(
7680
pixel_coords, pixel_normals, lights, cameras, materials
7781
)
@@ -122,7 +126,9 @@ def gouraud_shading(
122126
)
123127
verts_colors_shaded = vertex_colors * (ambient + diffuse) + specular
124128
face_colors = verts_colors_shaded[faces]
125-
colors = interpolate_face_attributes(fragments, face_colors)
129+
colors = interpolate_face_attributes(
130+
fragments.pix_to_face, fragments.bary_coords, face_colors
131+
)
126132
return colors
127133

128134

pytorch3d/renderer/mesh/texturing.py

Lines changed: 9 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -7,75 +7,7 @@
77

88
from pytorch3d.structures.textures import Textures
99

10-
11-
def _clip_barycentric_coordinates(bary) -> torch.Tensor:
12-
"""
13-
Args:
14-
bary: barycentric coordinates of shape (...., 3) where `...` represents
15-
an arbitrary number of dimensions
16-
17-
Returns:
18-
bary: All barycentric coordinate values clipped to the range [0, 1]
19-
and renormalized. The output is the same shape as the input.
20-
"""
21-
if bary.shape[-1] != 3:
22-
msg = "Expected barycentric coords to have last dim = 3; got %r"
23-
raise ValueError(msg % bary.shape)
24-
clipped = bary.clamp(min=0, max=1)
25-
clipped_sum = torch.clamp(clipped.sum(dim=-1, keepdim=True), min=1e-5)
26-
clipped = clipped / clipped_sum
27-
return clipped
28-
29-
30-
def interpolate_face_attributes(
31-
fragments, face_attributes: torch.Tensor, bary_clip: bool = False
32-
) -> torch.Tensor:
33-
"""
34-
Interpolate arbitrary face attributes using the barycentric coordinates
35-
for each pixel in the rasterized output.
36-
37-
Args:
38-
fragments:
39-
The outputs of rasterization. From this we use
40-
41-
- pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices
42-
of the faces (in the packed representation) which
43-
overlap each pixel in the image.
44-
- barycentric_coords: FloatTensor of shape (N, H, W, K, 3) specifying
45-
the barycentric coordianates of each pixel
46-
relative to the faces (in the packed
47-
representation) which overlap the pixel.
48-
face_attributes: packed attributes of shape (total_faces, 3, D),
49-
specifying the value of the attribute for each
50-
vertex in the face.
51-
bary_clip: Bool to indicate if barycentric_coords should be clipped
52-
before being used for interpolation.
53-
54-
Returns:
55-
pixel_vals: tensor of shape (N, H, W, K, D) giving the interpolated
56-
value of the face attribute for each pixel.
57-
"""
58-
pix_to_face = fragments.pix_to_face
59-
barycentric_coords = fragments.bary_coords
60-
F, FV, D = face_attributes.shape
61-
if FV != 3:
62-
raise ValueError("Faces can only have three vertices; got %r" % FV)
63-
N, H, W, K, _ = barycentric_coords.shape
64-
if pix_to_face.shape != (N, H, W, K):
65-
msg = "pix_to_face must have shape (batch_size, H, W, K); got %r"
66-
raise ValueError(msg % pix_to_face.shape)
67-
if bary_clip:
68-
barycentric_coords = _clip_barycentric_coordinates(barycentric_coords)
69-
70-
# Replace empty pixels in pix_to_face with 0 in order to interpolate.
71-
mask = pix_to_face == -1
72-
pix_to_face = pix_to_face.clone()
73-
pix_to_face[mask] = 0
74-
idx = pix_to_face.view(N * H * W * K, 1, 1).expand(N * H * W * K, 3, D)
75-
pixel_face_vals = face_attributes.gather(0, idx).view(N, H, W, K, 3, D)
76-
pixel_vals = (barycentric_coords[..., None] * pixel_face_vals).sum(dim=-2)
77-
pixel_vals[mask] = 0 # Replace masked values in output.
78-
return pixel_vals
10+
from .utils import interpolate_face_attributes
7911

8012

8113
def interpolate_texture_map(fragments, meshes) -> torch.Tensor:
@@ -97,8 +29,8 @@ def interpolate_texture_map(fragments, meshes) -> torch.Tensor:
9729
relative to the faces (in the packed
9830
representation) which overlap the pixel.
9931
meshes: Meshes representing a batch of meshes. It is expected that
100-
meshes has a textures attribute which is an instance of the
101-
Textures class.
32+
meshes has a textures attribute which is an instance of the
33+
Textures class.
10234
10335
Returns:
10436
texels: tensor of shape (N, H, W, K, C) giving the interpolated
@@ -114,7 +46,9 @@ def interpolate_texture_map(fragments, meshes) -> torch.Tensor:
11446
texture_maps = meshes.textures.maps_padded()
11547

11648
# pixel_uvs: (N, H, W, K, 2)
117-
pixel_uvs = interpolate_face_attributes(fragments, faces_verts_uvs)
49+
pixel_uvs = interpolate_face_attributes(
50+
fragments.pix_to_face, fragments.bary_coords, faces_verts_uvs
51+
)
11852

11953
N, H_out, W_out, K = fragments.pix_to_face.shape
12054
N, H_in, W_in, C = texture_maps.shape # 3 for RGB
@@ -178,5 +112,7 @@ def interpolate_vertex_colors(fragments, meshes) -> torch.Tensor:
178112
vertex_textures = vertex_textures[meshes.verts_padded_to_packed_idx(), :]
179113
faces_packed = meshes.faces_packed()
180114
faces_textures = vertex_textures[faces_packed] # (F, 3, C)
181-
texels = interpolate_face_attributes(fragments, faces_textures)
115+
texels = interpolate_face_attributes(
116+
fragments.pix_to_face, fragments.bary_coords, faces_textures
117+
)
182118
return texels

pytorch3d/renderer/mesh/utils.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
3+
4+
5+
import torch
6+
7+
8+
def _clip_barycentric_coordinates(bary) -> torch.Tensor:
9+
"""
10+
Args:
11+
bary: barycentric coordinates of shape (...., 3) where `...` represents
12+
an arbitrary number of dimensions
13+
14+
Returns:
15+
bary: Barycentric coordinates clipped (i.e any values < 0 are set to 0)
16+
and renormalized. We only clip the negative values. Values > 1 will fall
17+
into the [0, 1] range after renormalization.
18+
The output is the same shape as the input.
19+
"""
20+
if bary.shape[-1] != 3:
21+
msg = "Expected barycentric coords to have last dim = 3; got %r"
22+
raise ValueError(msg % bary.shape)
23+
clipped = bary.clamp(min=0.0)
24+
clipped_sum = torch.clamp(clipped.sum(dim=-1, keepdim=True), min=1e-5)
25+
clipped = clipped / clipped_sum
26+
return clipped
27+
28+
29+
def interpolate_face_attributes(
30+
pix_to_face: torch.Tensor,
31+
barycentric_coords: torch.Tensor,
32+
face_attributes: torch.Tensor,
33+
) -> torch.Tensor:
34+
"""
35+
Interpolate arbitrary face attributes using the barycentric coordinates
36+
for each pixel in the rasterized output.
37+
38+
Args:
39+
pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices
40+
of the faces (in the packed representation) which
41+
overlap each pixel in the image.
42+
barycentric_coords: FloatTensor of shape (N, H, W, K, 3) specifying
43+
the barycentric coordianates of each pixel
44+
relative to the faces (in the packed
45+
representation) which overlap the pixel.
46+
face_attributes: packed attributes of shape (total_faces, 3, D),
47+
specifying the value of the attribute for each
48+
vertex in the face.
49+
50+
Returns:
51+
pixel_vals: tensor of shape (N, H, W, K, D) giving the interpolated
52+
value of the face attribute for each pixel.
53+
"""
54+
F, FV, D = face_attributes.shape
55+
if FV != 3:
56+
raise ValueError("Faces can only have three vertices; got %r" % FV)
57+
N, H, W, K, _ = barycentric_coords.shape
58+
if pix_to_face.shape != (N, H, W, K):
59+
msg = "pix_to_face must have shape (batch_size, H, W, K); got %r"
60+
raise ValueError(msg % pix_to_face.shape)
61+
62+
# Replace empty pixels in pix_to_face with 0 in order to interpolate.
63+
mask = pix_to_face == -1
64+
pix_to_face = pix_to_face.clone()
65+
pix_to_face[mask] = 0
66+
idx = pix_to_face.view(N * H * W * K, 1, 1).expand(N * H * W * K, 3, D)
67+
pixel_face_vals = face_attributes.gather(0, idx).view(N, H, W, K, 3, D)
68+
pixel_vals = (barycentric_coords[..., None] * pixel_face_vals).sum(dim=-2)
69+
pixel_vals[mask] = 0 # Replace masked values in output.
70+
return pixel_vals
71+
72+
73+
def _interpolate_zbuf(
74+
pix_to_face: torch.Tensor, barycentric_coords: torch.Tensor, meshes
75+
) -> torch.Tensor:
76+
"""
77+
A helper function to calculate the z buffer for each pixel in the
78+
rasterized output.
79+
80+
Args:
81+
pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices
82+
of the faces (in the packed representation) which
83+
overlap each pixel in the image.
84+
barycentric_coords: FloatTensor of shape (N, H, W, K, 3) specifying
85+
the barycentric coordianates of each pixel
86+
relative to the faces (in the packed
87+
representation) which overlap the pixel.
88+
meshes: Meshes object representing a batch of meshes.
89+
90+
Returns:
91+
zbuffer: (N, H, W, K) FloatTensor
92+
"""
93+
verts = meshes.verts_packed()
94+
faces = meshes.faces_packed()
95+
faces_verts_z = verts[faces][..., 2][..., None] # (F, 3, 1)
96+
return interpolate_face_attributes(
97+
pix_to_face, barycentric_coords, faces_verts_z
98+
)[
99+
..., 0
100+
] # (1, H, W, K)
45.1 KB
Loading
26.1 KB
Loading
Loading

tests/test_mesh_rendering_utils.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
3+
4+
5+
import unittest
6+
import torch
7+
8+
from pytorch3d.renderer.mesh.utils import _clip_barycentric_coordinates
9+
10+
11+
class TestMeshRenderingUtils(unittest.TestCase):
12+
def test_bary_clip(self):
13+
N = 10
14+
bary = torch.randn(size=(N, 3))
15+
# randomly make some values negative
16+
bary[bary < 0.3] *= -1.0
17+
# randomly make some values be greater than 1
18+
bary[bary > 0.8] *= 2.0
19+
negative_mask = bary < 0.0
20+
positive_mask = bary > 1.0
21+
clipped = _clip_barycentric_coordinates(bary)
22+
self.assertTrue(clipped[negative_mask].sum() == 0)
23+
self.assertTrue(clipped[positive_mask].gt(1.0).sum() == 0)
24+
self.assertTrue(torch.allclose(clipped.sum(dim=-1), torch.ones(N)))

0 commit comments

Comments
 (0)