Skip to content

Commit 0c595dc

Browse files
bottlerfacebook-github-bot
authored andcommitted
Joining mismatched texture maps on CUDA #175
Summary: Use nn.functional.interpolate instead of a TorchVision transform to resize texture maps to a common value. This works on all devices. This fixes issue #175. Also fix the condition so it only happens when needed. Reviewed By: nikhilaravi Differential Revision: D21324510 fbshipit-source-id: c50eb06514984995bd81f2c44079be6e0b4098e4
1 parent e64e0d1 commit 0c595dc

File tree

2 files changed

+15
-11
lines changed

2 files changed

+15
-11
lines changed

pytorch3d/structures/textures.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import List, Optional, Union
44

55
import torch
6-
import torchvision.transforms as T
6+
from torch.nn.functional import interpolate
77

88
from .utils import padded_to_list, padded_to_packed
99

@@ -18,10 +18,10 @@ def _pad_texture_maps(images: List[torch.Tensor]) -> torch.Tensor:
1818
Pad all texture images so they have the same height and width.
1919
2020
Args:
21-
images: list of N tensors of shape (H, W)
21+
images: list of N tensors of shape (H, W, 3)
2222
2323
Returns:
24-
tex_maps: Tensor of shape (N, max_H, max_W)
24+
tex_maps: Tensor of shape (N, max_H, max_W, 3)
2525
"""
2626
tex_maps = []
2727
max_H = 0
@@ -35,15 +35,13 @@ def _pad_texture_maps(images: List[torch.Tensor]) -> torch.Tensor:
3535
tex_maps.append(im)
3636
max_shape = (max_H, max_W)
3737

38-
# If all texture images are not the same size then resize to the
39-
# largest size.
40-
resize = T.Compose([T.ToPILImage(), T.Resize(size=max_shape), T.ToTensor()])
41-
4238
for i, image in enumerate(tex_maps):
43-
if image.shape != max_shape:
44-
# ToPIL takes and returns a C x H x W tensor
45-
image = resize(image.permute(2, 0, 1)).permute(1, 2, 0)
46-
tex_maps[i] = image
39+
if image.shape[:2] != max_shape:
40+
image_BCHW = image.permute(2, 0, 1)[None]
41+
new_image_BCHW = interpolate(
42+
image_BCHW, size=max_shape, mode="bilinear", align_corners=False
43+
)
44+
tex_maps[i] = new_image_BCHW[0].permute(1, 2, 0)
4745
tex_maps = torch.stack(tex_maps, dim=0) # (num_tex_maps, max_H, max_W, 3)
4846
return tex_maps
4947

tests/test_obj_io.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -607,6 +607,12 @@ def check_item(x, y):
607607
check_triple(mesh, mesh3)
608608
self.assertTupleEqual(mesh.textures.maps_padded().shape, (1, 1024, 1024, 3))
609609

610+
# Try mismatched texture map sizes, which needs a call to interpolate()
611+
mesh2048 = mesh.clone()
612+
maps = mesh.textures.maps_padded()
613+
mesh2048.textures._maps_padded = torch.cat([maps, maps], dim=1)
614+
join_meshes_as_batch([mesh.to("cuda:0"), mesh2048.to("cuda:0")])
615+
610616
mesh_notex = load_objs_as_meshes([obj_filename], load_textures=False)
611617
mesh3_notex = load_objs_as_meshes(
612618
[obj_filename, obj_filename, obj_filename], load_textures=False

0 commit comments

Comments
 (0)