
Commit 8d10ba5

d4l3k authored and facebook-github-bot committed
renderer: add support for rendering high dimensional textures for classification/segmentation use cases (#1248)
Summary:
For 3D segmentation problems it's really useful to be able to train models from multiple viewpoints using PyTorch3D as the renderer. Currently, due to hardcoded assumptions in a few spots, the mesh renderer only supports rendering RGB (3 dimensional) data. You can encode classification information as 3 channel data, but if you have more than 3 classes you're out of luck.

This relaxes those assumptions so that rendering semantic classes works with `HardFlatShader` and `AmbientLights` with no diffuse/specular component. The other shaders/lights don't make sense for classification since they mutate the texture values in some way. This only requires changes in `Materials` and `AmbientLights`; the bulk of the code is the unit test.

Pull Request resolved: #1248

Test Plan: Added a unit test that renders a 5 dimensional texture and compares dimensions 2-5 to a stored reference image.

Reviewed By: bottler

Differential Revision: D37764610

Pulled By: d4l3k

fbshipit-source-id: 031895724d9318a6f6bab5b31055bb3f438176a5
1 parent aa8b03f commit 8d10ba5
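The core idea is that per-vertex class labels can be stored directly as a C-channel vertex texture instead of being squeezed into 3 RGB channels. Below is a minimal sketch of that encoding, assuming a hypothetical `class_ids` tensor of integer labels; none of this code is part of the commit.

import torch
from pytorch3d.renderer import TexturesVertex
from pytorch3d.utils import ico_sphere

device = torch.device("cuda:0")
mesh = ico_sphere(5, device)

# Hypothetical setup: one integer class label per vertex, with C > 3 classes.
num_classes = 5
n_verts = mesh.verts_padded().shape[1]
class_ids = torch.randint(0, num_classes, (1, n_verts), device=device)

# One-hot encode the labels into a (N, V, C) feature tensor and attach it
# as a vertex texture; this is the kind of high dimensional texture the
# renderer can now handle.
feats = torch.nn.functional.one_hot(class_ids, num_classes).float()
textures = TexturesVertex(verts_features=feats)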

File tree

4 files changed, +101 -12 lines

pytorch3d/renderer/lighting.py
pytorch3d/renderer/materials.py
tests/data/test_nd_sphere.png
tests/test_render_meshes.py


pytorch3d/renderer/lighting.py

Lines changed: 14 additions & 5 deletions
@@ -292,6 +292,9 @@ class AmbientLights(TensorProperties):
     A light object representing the same color of light everywhere.
     By default, this is white, which effectively means lighting is
     not used in rendering.
+
+    Unlike other lights this supports an arbitrary number of channels, not just 3 for RGB.
+    The ambient_color input determines the number of channels.
     """

     def __init__(self, *, ambient_color=None, device: Device = "cpu") -> None:
@@ -304,9 +307,11 @@ def __init__(self, *, ambient_color=None, device: Device = "cpu") -> None:
             device: Device (as str or torch.device) on which the tensors should be located

         The ambient_color if provided, should be
-            - 3 element tuple/list or list of lists
-            - torch tensor of shape (1, 3)
-            - torch tensor of shape (N, 3)
+            - tuple/list of C-element tuples of floats
+            - torch tensor of shape (1, C)
+            - torch tensor of shape (N, C)
+            where C is the number of channels and N is batch size.
+            For RGB, C is 3.
         """
         if ambient_color is None:
             ambient_color = ((1.0, 1.0, 1.0),)
@@ -317,10 +322,14 @@ def clone(self):
         return super().clone(other)

     def diffuse(self, normals, points) -> torch.Tensor:
-        return torch.zeros_like(points)
+        return self._zeros_channels(points)

     def specular(self, normals, points, camera_position, shininess) -> torch.Tensor:
-        return torch.zeros_like(points)
+        return self._zeros_channels(points)
+
+    def _zeros_channels(self, points: torch.Tensor) -> torch.Tensor:
+        ch = self.ambient_color.shape[-1]
+        return torch.zeros(*points.shape[:-1], ch, device=points.device)


 def _validate_light_properties(obj) -> None:
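As a quick illustration of the behavior this hunk enables (the tensor shapes here are illustrative, not from the commit): `AmbientLights` infers the channel count from `ambient_color`, and `diffuse`/`specular` return zero tensors with that many channels instead of mirroring the 3-channel shape of `points`.

import torch
from pytorch3d.renderer import AmbientLights

C = 5  # arbitrary channel count, e.g. one channel per semantic class
lights = AmbientLights(ambient_color=torch.ones(1, C))

points = torch.rand(1, 10, 3)   # illustrative point positions
normals = torch.rand(1, 10, 3)  # illustrative normals

# Both return zeros of shape (1, 10, C); camera_position and shininess are ignored.
print(lights.diffuse(normals, points).shape)                # torch.Size([1, 10, 5])
print(lights.specular(normals, points, None, None).shape)   # torch.Size([1, 10, 5])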

pytorch3d/renderer/materials.py

Lines changed: 9 additions & 7 deletions
@@ -27,17 +27,18 @@ def __init__(
     ) -> None:
         """
         Args:
-            ambient_color: RGB ambient reflectivity of the material
-            diffuse_color: RGB diffuse reflectivity of the material
-            specular_color: RGB specular reflectivity of the material
+            ambient_color: ambient reflectivity of the material
+            diffuse_color: diffuse reflectivity of the material
+            specular_color: specular reflectivity of the material
             shininess: The specular exponent for the material. This defines
                 the focus of the specular highlight with a high value
                 resulting in a concentrated highlight. Shininess values
                 can range from 0-1000.
             device: Device (as str or torch.device) on which the tensors should be located

         ambient_color, diffuse_color and specular_color can be of shape
-        (1, 3) or (N, 3). shininess can be of shape (1) or (N).
+        (1, C) or (N, C) where C is typically 3 (for RGB). shininess can be of shape (1,)
+        or (N,).

         The colors and shininess are broadcast against each other so need to
         have either the same batch dimension or batch dimension = 1.
@@ -49,11 +50,12 @@ def __init__(
             specular_color=specular_color,
             shininess=shininess,
         )
+        C = self.ambient_color.shape[-1]
         for n in ["ambient_color", "diffuse_color", "specular_color"]:
             t = getattr(self, n)
-            if t.shape[-1] != 3:
-                msg = "Expected %s to have shape (N, 3); got %r"
-                raise ValueError(msg % (n, t.shape))
+            if t.shape[-1] != C:
+                msg = "Expected %s to have shape (N, %d); got %r"
+                raise ValueError(msg % (n, C, t.shape))
         if self.shininess.shape != torch.Size([self._N]):
             msg = "shininess should have shape (N); got %r"
             raise ValueError(msg % repr(self.shininess.shape))
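A small sketch of how the relaxed check behaves after this commit (values and shapes are illustrative): all three colors must share the channel count inferred from `ambient_color`, which is no longer hardcoded to 3.

import torch
from pytorch3d.renderer import Materials

C = 5
white = ((1.0,) * C,)

# Valid: every color has C channels.
materials = Materials(
    ambient_color=white,
    diffuse_color=white,
    specular_color=white,
)

# Invalid: diffuse_color has 3 channels while ambient_color has 5,
# which triggers the (N, C) shape error added in this hunk.
try:
    Materials(ambient_color=white, diffuse_color=((1.0, 1.0, 1.0),), specular_color=white)
except ValueError as err:
    print(err)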

tests/data/test_nd_sphere.png

55.9 KB

tests/test_render_meshes.py

Lines changed: 78 additions & 0 deletions
@@ -1236,3 +1236,81 @@ def test_cameras_kwarg(self):
                 "test_simple_sphere_light_phong_%s.png" % cam_type.__name__, DATA_DIR
             )
             self.assertClose(rgb, image_ref, atol=0.05)
+
+    def test_nd_sphere(self):
+        """
+        Test that the render can handle textures with more than 3 channels and
+        not just 3 channel RGB.
+        """
+        torch.manual_seed(1)
+        device = torch.device("cuda:0")
+        C = 5
+        WHITE = ((1.0,) * C,)
+        BLACK = ((0.0,) * C,)
+
+        # Init mesh
+        sphere_mesh = ico_sphere(5, device)
+        verts_padded = sphere_mesh.verts_padded()
+        faces_padded = sphere_mesh.faces_padded()
+        feats = torch.ones(*verts_padded.shape[:-1], C, device=device)
+        n_verts = feats.shape[1]
+        # make some non-uniform pattern
+        feats *= torch.arange(0, 10, step=10 / n_verts, device=device).unsqueeze(1)
+        textures = TexturesVertex(verts_features=feats)
+        sphere_mesh = Meshes(verts=verts_padded, faces=faces_padded, textures=textures)
+
+        # No elevation or azimuth rotation
+        R, T = look_at_view_transform(2.7, 0.0, 0.0)
+
+        cameras = PerspectiveCameras(device=device, R=R, T=T)
+
+        # Init shader settings
+        materials = Materials(
+            device=device,
+            ambient_color=WHITE,
+            diffuse_color=WHITE,
+            specular_color=WHITE,
+        )
+        lights = AmbientLights(
+            device=device,
+            ambient_color=WHITE,
+        )
+        lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]
+
+        raster_settings = RasterizationSettings(
+            image_size=512, blur_radius=0.0, faces_per_pixel=1
+        )
+        rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
+        blend_params = BlendParams(
+            1e-4,
+            1e-4,
+            background_color=BLACK[0],
+        )
+
+        # only test HardFlatShader since that's the only one that makes
+        # sense for classification
+        shader = HardFlatShader(
+            lights=lights,
+            cameras=cameras,
+            materials=materials,
+            blend_params=blend_params,
+        )
+        renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
+        images = renderer(sphere_mesh)
+
+        self.assertEqual(images.shape[-1], C + 1)
+        self.assertClose(images.amax(), torch.tensor(10.0), atol=0.01)
+        self.assertClose(images.amin(), torch.tensor(0.0), atol=0.01)
+
+        # grab last 3 color channels
+        rgb = (images[0, ..., C - 3 : C] / 10).squeeze().cpu()
+        filename = "test_nd_sphere.png"
+
+        if DEBUG:
+            debug_filename = "DEBUG_%s" % filename
+            Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
+                DATA_DIR / debug_filename
+            )
+
+        image_ref = load_rgb_image(filename, DATA_DIR)
+        self.assertClose(rgb, image_ref, atol=0.05)
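Downstream of a render like the one in this test, the output could be turned back into per-pixel class predictions. This is a sketch under the assumption that the C texture channels hold per-class scores and the final channel is the blending alpha; the tensors below are stand-ins, not part of the commit.

import torch

C = 5
# Stand-in for renderer output of shape (N, H, W, C + 1).
images = torch.rand(1, 512, 512, C + 1)

class_map = images[..., :C].argmax(dim=-1)  # (N, H, W) integer class labels per pixel
alpha_mask = images[..., C] > 0.5           # pixels actually covered by the mesh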
