Skip to content

Support reading uv and uv map for ply format if texture_uv exists in ply file #1100

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 44 additions & 5 deletions pytorch3d/io/ply_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import struct
import sys
import warnings
from os import path as osp
from collections import namedtuple
from dataclasses import asdict, dataclass
from io import BytesIO, TextIOBase
Expand All @@ -21,8 +22,8 @@
import numpy as np
import torch
from iopath.common.file_io import PathManager
from pytorch3d.io.utils import PathOrStr, _check_faces_indices, _make_tensor, _open_file
from pytorch3d.renderer import TexturesVertex
from pytorch3d.io.utils import PathOrStr, _check_faces_indices, _make_tensor, _open_file, _read_image
from pytorch3d.renderer import TexturesVertex, TexturesUV
from pytorch3d.structures import Meshes, Pointclouds

from .pluggable_formats import (
Expand Down Expand Up @@ -804,6 +805,7 @@ class _VertsColumnIndices:
color_idxs: Optional[List[int]]
color_scale: float
normal_idxs: Optional[List[int]]
texture_uv_idxs: Optional[List[int]]


def _get_verts_column_indices(
Expand All @@ -827,6 +829,8 @@ def _get_verts_column_indices(
property uchar red
property uchar green
property uchar blue
property double texture_u
property double texture_v

then the return value will be ([0,1,2], [6,7,8], 1.0/255, [3,4,5])

Expand All @@ -839,6 +843,7 @@ def _get_verts_column_indices(
point_idxs: List[Optional[int]] = [None, None, None]
color_idxs: List[Optional[int]] = [None, None, None]
normal_idxs: List[Optional[int]] = [None, None, None]
texture_uv_idxs : List[Optional[int]] = [None, None]
for i, prop in enumerate(vertex_head.properties):
if prop.list_size_type is not None:
raise ValueError("Invalid vertices in file: did not expect list.")
Expand All @@ -851,6 +856,9 @@ def _get_verts_column_indices(
for j, name in enumerate(["nx", "ny", "nz"]):
if prop.name == name:
normal_idxs[j] = i
for j, name in enumerate(["texture_u", "texture_v"]):
if prop.name == name:
texture_uv_idxs[j] = i
if None in point_idxs:
raise ValueError("Invalid vertices in file.")
color_scale = 1.0
Expand All @@ -864,6 +872,7 @@ def _get_verts_column_indices(
color_idxs=None if None in color_idxs else color_idxs,
color_scale=color_scale,
normal_idxs=None if None in normal_idxs else normal_idxs,
texture_uv_idxs=None if None in texture_uv_idxs else texture_uv_idxs,
)


Expand All @@ -880,6 +889,7 @@ class _VertsData:
verts: torch.Tensor
verts_colors: Optional[torch.Tensor] = None
verts_normals: Optional[torch.Tensor] = None
verts_texture_uvs: Optional[torch.Tensor] = None


def _get_verts(header: _PlyHeader, elements: dict) -> _VertsData:
Expand Down Expand Up @@ -922,6 +932,7 @@ def _get_verts(header: _PlyHeader, elements: dict) -> _VertsData:

vertex_colors = None
vertex_normals = None
vertex_texture_uvs = None

if len(vertex) == 1:
# This is the case where the whole vertex element has one type,
Expand All @@ -935,6 +946,10 @@ def _get_verts(header: _PlyHeader, elements: dict) -> _VertsData:
vertex_normals = torch.tensor(
vertex[0][:, column_idxs.normal_idxs], dtype=torch.float32
)
if column_idxs.texture_uv_idxs is not None:
vertex_texture_uvs = torch.tensor(
vertex[0][:, column_idxs.texture_uv_idxs], dtype=torch.float32
)
else:
# The vertex element is heterogeneous. It was read as several arrays,
# part by part, where a part is a set of properties with the same type.
Expand Down Expand Up @@ -973,11 +988,18 @@ def _get_verts(header: _PlyHeader, elements: dict) -> _VertsData:
for axis in range(3):
partnum, col = prop_to_partnum_col[column_idxs.normal_idxs[axis]]
vertex_normals.numpy()[:, axis] = vertex[partnum][:, col]

if column_idxs.texture_uv_idxs is not None:
vertex_texture_uvs = torch.empty(
size=(vertex_head.count, 2), dtype=torch.float32,
)
for axis in range(2):
partnum, col = prop_to_partnum_col[column_idxs.texture_uv_idxs[axis]]
vertex_texture_uvs.numpy()[:, axis] = vertex[partnum][:, col]
return _VertsData(
verts=verts,
verts_colors=vertex_colors,
verts_normals=vertex_normals,
verts_texture_uvs=vertex_texture_uvs,
)


Expand All @@ -998,6 +1020,7 @@ class _PlyData:
faces: Optional[torch.Tensor]
verts_colors: Optional[torch.Tensor]
verts_normals: Optional[torch.Tensor]
verts_texture_uvs : Optional[torch.Tensor]


def _load_ply(f, *, path_manager: PathManager) -> _PlyData:
Expand Down Expand Up @@ -1358,8 +1381,24 @@ def read(
faces = torch.zeros(0, 3, dtype=torch.int64)

texture = None
if include_textures and data.verts_colors is not None:
texture = TexturesVertex([data.verts_colors.to(device)])
if include_textures:
if data.verts_colors is not None:
texture = TexturesVertex([data.verts_colors.to(device)])
elif data.verts_texture_uvs is not None:
texture_file_path = None
for comment in data.header.comments:
if 'TextureFile' in comment:
texture_file_path = comment.split(' ')[-1]
texture_file_path = osp.join(osp.dirname(path), texture_file_path)
if texture_file_path is not None:
texture_map = _read_image(texture_file_path, path_manager, format='RGB')
texture_map = torch.tensor(texture_map, dtype=torch.float32) / 255.
texture = TexturesUV(
[texture_map.to(device)], [faces.to(device)], [data.verts_texture_uvs.to(device)])
else:
texture = TexturesVertex([torch.ones_like(data.verts).to(device)])
print('Warning: No texture found, init the texture with white color')


verts_normals = None
if data.verts_normals is not None:
Expand Down
4 changes: 3 additions & 1 deletion pytorch3d/renderer/mesh/rasterizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,9 @@ def __init__(self, cameras=None, raster_settings=None) -> None:

def to(self, device):
# Manually move to device cameras as it is not a subclass of nn.Module
self.cameras = self.cameras.to(device)
cameras = self.cameras
if cameras is not None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks like a significant bug you are fixing, which is separate from this change. Thank you. It would be good to put it in its own PR which would include a test for this problem.

(Or let us know and we can do it internally.)

self.cameras = cameras.to(device)
return self

def transform(self, meshes_world, **kwargs) -> torch.Tensor:
Expand Down
2 changes: 1 addition & 1 deletion pytorch3d/renderer/mesh/textures.py
Original file line number Diff line number Diff line change
Expand Up @@ -1073,7 +1073,7 @@ def join_batch(self, textures: List["TexturesUV"]) -> "TexturesUV":
faces_uvs_list += self.faces_uvs_list()
verts_uvs_list += self.verts_uvs_list()
maps_list += self.maps_list()
num_faces_per_mesh = self._num_faces_per_mesh
num_faces_per_mesh = self._num_faces_per_mesh.copy()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This also looks like a separate significant bug you've fixed, best in its own PR. Thanks again. Probably the test change would be an extra assertion in an existing test.

(Or let us know and we can sort it out.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In fact, I think TexturesAtlas has exactly the same bug. They should be fixed at the same time.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is a significant bug which was found when I was using Pytorch3D. Specifically, for meshes with TextureUV, calling function join_meshes_as_batch will cause the origin mesh textureUV's _num_faces_per_mesh changed, due to the list reference attribute. Especially when calling join_meshes_per_mesh multiple times, this will cause OOM finally, because the original mesh textureUV's _num_faces_per_mesh list has became too long. I have not noticed the TextureAltas has the same bug because I have not used it. This is just a small change, so I committed it together. Can you please sort it out if needed?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. Let me sort it out. Best to remove this change from this PR still.

for tex in textures:
verts_uvs_list += tex.verts_uvs_list()
faces_uvs_list += tex.faces_uvs_list()
Expand Down