Skip to content

Commit a393296

Browse files
nikhilaravifacebook-github-bot
authored andcommitted
Texturing API updates
Summary: A fairly big refactor of the texturing API with some breaking changes to how textures are defined. Main changes: - There are now 3 types of texture classes: `TexturesUV`, `TexturesAtlas` and `TexturesVertex`. Each class: - has a `sample_textures` function which accepts the `fragments` from rasterization and returns `texels`. This means that the shaders will not need to know the type of the mesh texture which will resolve several issues people were reporting on GitHub. - has a `join_batch` method for joining multiple textures of the same type into a batch Reviewed By: gkioxari Differential Revision: D21067427 fbshipit-source-id: 4b346500a60181e72fdd1b0dd89b5505c7a33926
1 parent b73d3d6 commit a393296

19 files changed

+1864
-777
lines changed

docs/tutorials/camera_position_optimization_with_differentiable_rendering.ipynb

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -520,6 +520,9 @@
520520
],
521521
"metadata": {
522522
"accelerator": "GPU",
523+
"anp_metadata": {
524+
"path": "fbsource/fbcode/vision/fair/pytorch3d/docs/tutorials/camera_position_optimization_with_differentiable_rendering.ipynb"
525+
},
523526
"bento_stylesheets": {
524527
"bento/extensions/flow/main.css": true,
525528
"bento/extensions/kernel_selector/main.css": true,
@@ -533,6 +536,9 @@
533536
"provenance": [],
534537
"toc_visible": true
535538
},
539+
"disseminate_notebook_info": {
540+
"backup_notebook_id": "1062179640844868"
541+
},
536542
"kernelspec": {
537543
"display_name": "pytorch3d (local)",
538544
"language": "python",

docs/tutorials/render_textured_meshes.ipynb

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@
8484
"from skimage.io import imread\n",
8585
"\n",
8686
"# Util function for loading meshes\n",
87-
"from pytorch3d.io import load_objs_as_meshes\n",
87+
"from pytorch3d.io import load_objs_as_meshes, load_obj\n",
8888
"\n",
8989
"# Data structures and functions for rendering\n",
9090
"from pytorch3d.structures import Meshes, Textures\n",
@@ -97,7 +97,7 @@
9797
" RasterizationSettings, \n",
9898
" MeshRenderer, \n",
9999
" MeshRasterizer, \n",
100-
" TexturedSoftPhongShader\n",
100+
" SoftPhongShader\n",
101101
")\n",
102102
"\n",
103103
"# add path for demo utils functions \n",
@@ -316,7 +316,7 @@
316316
" cameras=cameras, \n",
317317
" raster_settings=raster_settings\n",
318318
" ),\n",
319-
" shader=TexturedSoftPhongShader(\n",
319+
" shader=SoftPhongShader(\n",
320320
" device=device, \n",
321321
" cameras=cameras,\n",
322322
" lights=lights\n",
@@ -563,6 +563,9 @@
563563
],
564564
"metadata": {
565565
"accelerator": "GPU",
566+
"anp_metadata": {
567+
"path": "fbsource/fbcode/vision/fair/pytorch3d/docs/tutorials/render_textured_meshes.ipynb"
568+
},
566569
"bento_stylesheets": {
567570
"bento/extensions/flow/main.css": true,
568571
"bento/extensions/kernel_selector/main.css": true,
@@ -575,6 +578,9 @@
575578
"name": "render_textured_meshes.ipynb",
576579
"provenance": []
577580
},
581+
"disseminate_notebook_info": {
582+
"backup_notebook_id": "569222367081034"
583+
},
578584
"kernelspec": {
579585
"display_name": "pytorch3d (local)",
580586
"language": "python",

pytorch3d/datasets/shapenet_base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
OpenGLPerspectiveCameras,
1414
PointLights,
1515
RasterizationSettings,
16+
TexturesVertex,
1617
)
17-
from pytorch3d.structures import Textures
1818

1919

2020
class ShapeNetBase(torch.utils.data.Dataset):
@@ -113,8 +113,8 @@ def render(
113113
"""
114114
paths = self._handle_render_inputs(model_ids, categories, sample_nums, idxs)
115115
meshes = load_objs_as_meshes(paths, device=device, load_textures=False)
116-
meshes.textures = Textures(
117-
verts_rgb=torch.ones_like(meshes.verts_padded(), device=device)
116+
meshes.textures = TexturesVertex(
117+
verts_features=torch.ones_like(meshes.verts_padded(), device=device)
118118
)
119119
cameras = kwargs.get("cameras", OpenGLPerspectiveCameras()).to(device)
120120
renderer = MeshRenderer(

pytorch3d/io/obj_io.py

Lines changed: 41 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
import torch
1212
from pytorch3d.io.mtl_io import load_mtl, make_mesh_texture_atlas
1313
from pytorch3d.io.utils import _open_file
14-
from pytorch3d.structures import Meshes, Textures, join_meshes_as_batch
14+
from pytorch3d.renderer import TexturesAtlas, TexturesUV
15+
from pytorch3d.structures import Meshes, join_meshes_as_batch
1516

1617

1718
def _make_tensor(data, cols: int, dtype: torch.dtype, device="cpu") -> torch.Tensor:
@@ -41,6 +42,10 @@ def _format_faces_indices(faces_indices, max_index, device, pad_value=None):
4142
Args:
4243
faces_indices: List of ints of indices.
4344
max_index: Max index for the face property.
45+
pad_value: if any of the face_indices are padded, specify
46+
the value of the padding (e.g. -1). This is only used
47+
for texture indices indices where there might
48+
not be texture information for all the faces.
4449
4550
Returns:
4651
faces_indices: List of ints of indices.
@@ -65,7 +70,9 @@ def _format_faces_indices(faces_indices, max_index, device, pad_value=None):
6570
faces_indices[mask] = pad_value
6671

6772
# Check indices are valid.
68-
if torch.any(faces_indices >= max_index) or torch.any(faces_indices < 0):
73+
if torch.any(faces_indices >= max_index) or (
74+
pad_value is None and torch.any(faces_indices < 0)
75+
):
6976
warnings.warn("Faces have invalid indices")
7077

7178
return faces_indices
@@ -227,7 +234,14 @@ def load_obj(
227234
)
228235

229236

230-
def load_objs_as_meshes(files: list, device=None, load_textures: bool = True):
237+
def load_objs_as_meshes(
238+
files: list,
239+
device=None,
240+
load_textures: bool = True,
241+
create_texture_atlas: bool = False,
242+
texture_atlas_size: int = 4,
243+
texture_wrap: Optional[str] = "repeat",
244+
):
231245
"""
232246
Load meshes from a list of .obj files using the load_obj function, and
233247
return them as a Meshes object. This only works for meshes which have a
@@ -246,18 +260,31 @@ def load_objs_as_meshes(files: list, device=None, load_textures: bool = True):
246260
"""
247261
mesh_list = []
248262
for f_obj in files:
249-
# TODO: update this function to support the two texturing options.
250-
verts, faces, aux = load_obj(f_obj, load_textures=load_textures)
251-
verts = verts.to(device)
263+
verts, faces, aux = load_obj(
264+
f_obj,
265+
load_textures=load_textures,
266+
create_texture_atlas=create_texture_atlas,
267+
texture_atlas_size=texture_atlas_size,
268+
texture_wrap=texture_wrap,
269+
)
252270
tex = None
253-
tex_maps = aux.texture_images
254-
if tex_maps is not None and len(tex_maps) > 0:
255-
verts_uvs = aux.verts_uvs[None, ...].to(device) # (1, V, 2)
256-
faces_uvs = faces.textures_idx[None, ...].to(device) # (1, F, 3)
257-
image = list(tex_maps.values())[0].to(device)[None]
258-
tex = Textures(verts_uvs=verts_uvs, faces_uvs=faces_uvs, maps=image)
259-
260-
mesh = Meshes(verts=[verts], faces=[faces.verts_idx.to(device)], textures=tex)
271+
if create_texture_atlas:
272+
# TexturesAtlas type
273+
tex = TexturesAtlas(atlas=[aux.texture_atlas])
274+
else:
275+
# TexturesUV type
276+
tex_maps = aux.texture_images
277+
if tex_maps is not None and len(tex_maps) > 0:
278+
verts_uvs = aux.verts_uvs.to(device) # (V, 2)
279+
faces_uvs = faces.textures_idx.to(device) # (F, 3)
280+
image = list(tex_maps.values())[0].to(device)[None]
281+
tex = TexturesUV(
282+
verts_uvs=[verts_uvs], faces_uvs=[faces_uvs], maps=image
283+
)
284+
285+
mesh = Meshes(
286+
verts=[verts.to(device)], faces=[faces.verts_idx.to(device)], textures=tex
287+
)
261288
mesh_list.append(mesh)
262289
if len(mesh_list) == 1:
263290
return mesh_list[0]

pytorch3d/renderer/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,11 @@
2828
SoftGouraudShader,
2929
SoftPhongShader,
3030
SoftSilhouetteShader,
31-
TexturedSoftPhongShader,
31+
Textures,
32+
TexturesAtlas,
33+
TexturesUV,
34+
TexturesVertex,
3235
gouraud_shading,
33-
interpolate_face_attributes,
34-
interpolate_texture_map,
35-
interpolate_vertex_colors,
3636
phong_shading,
3737
rasterize_meshes,
3838
)

pytorch3d/renderer/mesh/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,21 @@
11
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
22

33

4-
from .texturing import interpolate_texture_map, interpolate_vertex_colors # isort:skip
54
from .rasterize_meshes import rasterize_meshes
65
from .rasterizer import MeshRasterizer, RasterizationSettings
76
from .renderer import MeshRenderer
7+
from .shader import TexturedSoftPhongShader # DEPRECATED
88
from .shader import (
99
HardFlatShader,
1010
HardGouraudShader,
1111
HardPhongShader,
1212
SoftGouraudShader,
1313
SoftPhongShader,
1414
SoftSilhouetteShader,
15-
TexturedSoftPhongShader,
1615
)
1716
from .shading import gouraud_shading, phong_shading
18-
from .utils import interpolate_face_attributes
17+
from .textures import Textures # DEPRECATED
18+
from .textures import TexturesAtlas, TexturesUV, TexturesVertex
1919

2020

2121
__all__ = [k for k in globals().keys() if not k.startswith("_")]

pytorch3d/renderer/mesh/shader.py

Lines changed: 30 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
22

3+
import warnings
34

45
import torch
56
import torch.nn as nn
@@ -13,7 +14,6 @@
1314
from ..lighting import PointLights
1415
from ..materials import Materials
1516
from .shading import flat_shading, gouraud_shading, phong_shading
16-
from .texturing import interpolate_texture_map, interpolate_vertex_colors
1717

1818

1919
# A Shader should take as input fragments from the output of rasterization
@@ -57,7 +57,7 @@ def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
5757
or in the forward pass of HardPhongShader"
5858
raise ValueError(msg)
5959

60-
texels = interpolate_vertex_colors(fragments, meshes)
60+
texels = meshes.sample_textures(fragments)
6161
lights = kwargs.get("lights", self.lights)
6262
materials = kwargs.get("materials", self.materials)
6363
blend_params = kwargs.get("blend_params", self.blend_params)
@@ -104,9 +104,11 @@ def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
104104
msg = "Cameras must be specified either at initialization \
105105
or in the forward pass of SoftPhongShader"
106106
raise ValueError(msg)
107-
texels = interpolate_vertex_colors(fragments, meshes)
107+
108+
texels = meshes.sample_textures(fragments)
108109
lights = kwargs.get("lights", self.lights)
109110
materials = kwargs.get("materials", self.materials)
111+
blend_params = kwargs.get("blend_params", self.blend_params)
110112
colors = phong_shading(
111113
meshes=meshes,
112114
fragments=fragments,
@@ -115,7 +117,7 @@ def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
115117
cameras=cameras,
116118
materials=materials,
117119
)
118-
images = softmax_rgb_blend(colors, fragments, self.blend_params)
120+
images = softmax_rgb_blend(colors, fragments, blend_params)
119121
return images
120122

121123

@@ -154,6 +156,12 @@ def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
154156
lights = kwargs.get("lights", self.lights)
155157
materials = kwargs.get("materials", self.materials)
156158
blend_params = kwargs.get("blend_params", self.blend_params)
159+
160+
# As Gouraud shading applies the illumination to the vertex
161+
# colors, the interpolated pixel texture is calculated in the
162+
# shading step. In comparison, for Phong shading, the pixel
163+
# textures are computed first after which the illumination is
164+
# applied.
157165
pixel_colors = gouraud_shading(
158166
meshes=meshes,
159167
fragments=fragments,
@@ -210,54 +218,25 @@ def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
210218
return images
211219

212220

213-
class TexturedSoftPhongShader(nn.Module):
221+
def TexturedSoftPhongShader(
222+
device="cpu", cameras=None, lights=None, materials=None, blend_params=None
223+
):
214224
"""
215-
Per pixel lighting applied to a texture map. First interpolate the vertex
216-
uv coordinates and sample from a texture map. Then apply the lighting model
217-
using the interpolated coords and normals for each pixel.
218-
219-
The blending function returns the soft aggregated color using all
220-
the faces per pixel.
221-
222-
To use the default values, simply initialize the shader with the desired
223-
device e.g.
224-
225-
.. code-block::
226-
227-
shader = TexturedPhongShader(device=torch.device("cuda:0"))
225+
TexturedSoftPhongShader class has been DEPRECATED. Use SoftPhongShader instead.
226+
Preserving TexturedSoftPhongShader as a function for backwards compatibility.
228227
"""
229-
230-
def __init__(
231-
self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None
232-
):
233-
super().__init__()
234-
self.lights = lights if lights is not None else PointLights(device=device)
235-
self.materials = (
236-
materials if materials is not None else Materials(device=device)
237-
)
238-
self.cameras = cameras
239-
self.blend_params = blend_params if blend_params is not None else BlendParams()
240-
241-
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
242-
cameras = kwargs.get("cameras", self.cameras)
243-
if cameras is None:
244-
msg = "Cameras must be specified either at initialization \
245-
or in the forward pass of TexturedSoftPhongShader"
246-
raise ValueError(msg)
247-
texels = interpolate_texture_map(fragments, meshes)
248-
lights = kwargs.get("lights", self.lights)
249-
materials = kwargs.get("materials", self.materials)
250-
blend_params = kwargs.get("blend_params", self.blend_params)
251-
colors = phong_shading(
252-
meshes=meshes,
253-
fragments=fragments,
254-
texels=texels,
255-
lights=lights,
256-
cameras=cameras,
257-
materials=materials,
258-
)
259-
images = softmax_rgb_blend(colors, fragments, blend_params)
260-
return images
228+
warnings.warn(
229+
"""TexturedSoftPhongShader is now deprecated;
230+
use SoftPhongShader instead.""",
231+
PendingDeprecationWarning,
232+
)
233+
return SoftPhongShader(
234+
device=device,
235+
cameras=cameras,
236+
lights=lights,
237+
materials=materials,
238+
blend_params=blend_params,
239+
)
261240

262241

263242
class HardFlatShader(nn.Module):
@@ -291,7 +270,7 @@ def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
291270
msg = "Cameras must be specified either at initialization \
292271
or in the forward pass of HardFlatShader"
293272
raise ValueError(msg)
294-
texels = interpolate_vertex_colors(fragments, meshes)
273+
texels = meshes.sample_textures(fragments)
295274
lights = kwargs.get("lights", self.lights)
296275
materials = kwargs.get("materials", self.materials)
297276
blend_params = kwargs.get("blend_params", self.blend_params)

0 commit comments

Comments
 (0)