diff --git a/pytorch3d/io/ply_io.py b/pytorch3d/io/ply_io.py index 37c86517f..020d5eae9 100644 --- a/pytorch3d/io/ply_io.py +++ b/pytorch3d/io/ply_io.py @@ -13,6 +13,7 @@ import struct import sys import warnings +from os import path as osp from collections import namedtuple from dataclasses import asdict, dataclass from io import BytesIO, TextIOBase @@ -21,8 +22,8 @@ import numpy as np import torch from iopath.common.file_io import PathManager -from pytorch3d.io.utils import PathOrStr, _check_faces_indices, _make_tensor, _open_file -from pytorch3d.renderer import TexturesVertex +from pytorch3d.io.utils import PathOrStr, _check_faces_indices, _make_tensor, _open_file, _read_image +from pytorch3d.renderer import TexturesVertex, TexturesUV from pytorch3d.structures import Meshes, Pointclouds from .pluggable_formats import ( @@ -804,6 +805,7 @@ class _VertsColumnIndices: color_idxs: Optional[List[int]] color_scale: float normal_idxs: Optional[List[int]] + texture_uv_idxs: Optional[List[int]] def _get_verts_column_indices( @@ -827,6 +829,8 @@ def _get_verts_column_indices( property uchar red property uchar green property uchar blue + property double texture_u + property double texture_v then the return value will be ([0,1,2], [6,7,8], 1.0/255, [3,4,5]) @@ -839,6 +843,7 @@ def _get_verts_column_indices( point_idxs: List[Optional[int]] = [None, None, None] color_idxs: List[Optional[int]] = [None, None, None] normal_idxs: List[Optional[int]] = [None, None, None] + texture_uv_idxs : List[Optional[int]] = [None, None] for i, prop in enumerate(vertex_head.properties): if prop.list_size_type is not None: raise ValueError("Invalid vertices in file: did not expect list.") @@ -851,6 +856,9 @@ def _get_verts_column_indices( for j, name in enumerate(["nx", "ny", "nz"]): if prop.name == name: normal_idxs[j] = i + for j, name in enumerate(["texture_u", "texture_v"]): + if prop.name == name: + texture_uv_idxs[j] = i if None in point_idxs: raise ValueError("Invalid vertices in file.") color_scale = 1.0 @@ -864,6 +872,7 @@ def _get_verts_column_indices( color_idxs=None if None in color_idxs else color_idxs, color_scale=color_scale, normal_idxs=None if None in normal_idxs else normal_idxs, + texture_uv_idxs=None if None in texture_uv_idxs else texture_uv_idxs, ) @@ -880,6 +889,7 @@ class _VertsData: verts: torch.Tensor verts_colors: Optional[torch.Tensor] = None verts_normals: Optional[torch.Tensor] = None + verts_texture_uvs: Optional[torch.Tensor] = None def _get_verts(header: _PlyHeader, elements: dict) -> _VertsData: @@ -922,6 +932,7 @@ def _get_verts(header: _PlyHeader, elements: dict) -> _VertsData: vertex_colors = None vertex_normals = None + vertex_texture_uvs = None if len(vertex) == 1: # This is the case where the whole vertex element has one type, @@ -935,6 +946,10 @@ def _get_verts(header: _PlyHeader, elements: dict) -> _VertsData: vertex_normals = torch.tensor( vertex[0][:, column_idxs.normal_idxs], dtype=torch.float32 ) + if column_idxs.texture_uv_idxs is not None: + vertex_texture_uvs = torch.tensor( + vertex[0][:, column_idxs.texture_uv_idxs], dtype=torch.float32 + ) else: # The vertex element is heterogeneous. It was read as several arrays, # part by part, where a part is a set of properties with the same type. @@ -973,11 +988,18 @@ def _get_verts(header: _PlyHeader, elements: dict) -> _VertsData: for axis in range(3): partnum, col = prop_to_partnum_col[column_idxs.normal_idxs[axis]] vertex_normals.numpy()[:, axis] = vertex[partnum][:, col] - + if column_idxs.texture_uv_idxs is not None: + vertex_texture_uvs = torch.empty( + size=(vertex_head.count, 2), dtype=torch.float32, + ) + for axis in range(2): + partnum, col = prop_to_partnum_col[column_idxs.texture_uv_idxs[axis]] + vertex_texture_uvs.numpy()[:, axis] = vertex[partnum][:, col] return _VertsData( verts=verts, verts_colors=vertex_colors, verts_normals=vertex_normals, + verts_texture_uvs=vertex_texture_uvs, ) @@ -998,6 +1020,7 @@ class _PlyData: faces: Optional[torch.Tensor] verts_colors: Optional[torch.Tensor] verts_normals: Optional[torch.Tensor] + verts_texture_uvs : Optional[torch.Tensor] def _load_ply(f, *, path_manager: PathManager) -> _PlyData: @@ -1358,8 +1381,24 @@ def read( faces = torch.zeros(0, 3, dtype=torch.int64) texture = None - if include_textures and data.verts_colors is not None: - texture = TexturesVertex([data.verts_colors.to(device)]) + if include_textures: + if data.verts_colors is not None: + texture = TexturesVertex([data.verts_colors.to(device)]) + elif data.verts_texture_uvs is not None: + texture_file_path = None + for comment in data.header.comments: + if 'TextureFile' in comment: + texture_file_path = comment.split(' ')[-1] + texture_file_path = osp.join(osp.dirname(path), texture_file_path) + if texture_file_path is not None: + texture_map = _read_image(texture_file_path, path_manager, format='RGB') + texture_map = torch.tensor(texture_map, dtype=torch.float32) / 255. + texture = TexturesUV( + [texture_map.to(device)], [faces.to(device)], [data.verts_texture_uvs.to(device)]) + else: + texture = TexturesVertex([torch.ones_like(data.verts).to(device)]) + print('Warning: No texture found, init the texture with white color') + verts_normals = None if data.verts_normals is not None: diff --git a/pytorch3d/renderer/mesh/rasterizer.py b/pytorch3d/renderer/mesh/rasterizer.py index bd73f16d6..d60c2f70b 100644 --- a/pytorch3d/renderer/mesh/rasterizer.py +++ b/pytorch3d/renderer/mesh/rasterizer.py @@ -110,7 +110,9 @@ def __init__(self, cameras=None, raster_settings=None) -> None: def to(self, device): # Manually move to device cameras as it is not a subclass of nn.Module - self.cameras = self.cameras.to(device) + cameras = self.cameras + if cameras is not None: + self.cameras = cameras.to(device) return self def transform(self, meshes_world, **kwargs) -> torch.Tensor: diff --git a/pytorch3d/renderer/mesh/textures.py b/pytorch3d/renderer/mesh/textures.py index c9f4308f2..9454156cd 100644 --- a/pytorch3d/renderer/mesh/textures.py +++ b/pytorch3d/renderer/mesh/textures.py @@ -1073,7 +1073,7 @@ def join_batch(self, textures: List["TexturesUV"]) -> "TexturesUV": faces_uvs_list += self.faces_uvs_list() verts_uvs_list += self.verts_uvs_list() maps_list += self.maps_list() - num_faces_per_mesh = self._num_faces_per_mesh + num_faces_per_mesh = self._num_faces_per_mesh.copy() for tex in textures: verts_uvs_list += tex.verts_uvs_list() faces_uvs_list += tex.faces_uvs_list()