Skip to content

Commit 89532a8

Browse files
bottlerfacebook-github-bot
authored andcommitted
add existing mesh formats to pluggable
Summary: We already have code for obj and ply formats. Here we actually make it available in `IO.load_mesh` and `IO.save_mesh`. Reviewed By: theschnitz, nikhilaravi Differential Revision: D25400650 fbshipit-source-id: f26d6d7fc46c48634a948eea4d255afad13b807b
1 parent b183dcb commit 89532a8

File tree

5 files changed

+271
-65
lines changed

5 files changed

+271
-65
lines changed

pytorch3d/io/obj_io.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
import os
66
import warnings
77
from collections import namedtuple
8-
from typing import List, Optional
8+
from pathlib import Path
9+
from typing import List, Optional, Union
910

1011
import numpy as np
1112
import torch
@@ -15,6 +16,8 @@
1516
from pytorch3d.renderer import TexturesAtlas, TexturesUV
1617
from pytorch3d.structures import Meshes, join_meshes_as_batch
1718

19+
from .pluggable_formats import MeshFormatInterpreter, endswith
20+
1821

1922
# Faces & Aux type returned from load_obj function.
2023
_Faces = namedtuple("Faces", "verts_idx normals_idx textures_idx materials_idx")
@@ -286,6 +289,58 @@ def load_objs_as_meshes(
286289
return join_meshes_as_batch(mesh_list)
287290

288291

292+
class MeshObjFormat(MeshFormatInterpreter):
293+
def __init__(self):
294+
self.known_suffixes = (".obj",)
295+
296+
def read(
297+
self,
298+
path: Union[str, Path],
299+
include_textures: bool,
300+
device,
301+
path_manager: PathManager,
302+
create_texture_atlas: bool = False,
303+
texture_atlas_size: int = 4,
304+
texture_wrap: Optional[str] = "repeat",
305+
**kwargs,
306+
) -> Optional[Meshes]:
307+
if not endswith(path, self.known_suffixes):
308+
return None
309+
mesh = load_objs_as_meshes(
310+
files=[path],
311+
device=device,
312+
load_textures=include_textures,
313+
create_texture_atlas=create_texture_atlas,
314+
texture_atlas_size=texture_atlas_size,
315+
texture_wrap=texture_wrap,
316+
path_manager=path_manager,
317+
)
318+
return mesh
319+
320+
def save(
321+
self,
322+
data: Meshes,
323+
path: Union[str, Path],
324+
path_manager: PathManager,
325+
binary: Optional[bool],
326+
decimal_places: Optional[int] = None,
327+
**kwargs,
328+
) -> bool:
329+
if not endswith(path, self.known_suffixes):
330+
return False
331+
332+
verts = data.verts_list()[0]
333+
faces = data.faces_list()[0]
334+
save_obj(
335+
f=path,
336+
verts=verts,
337+
faces=faces,
338+
decimal_places=decimal_places,
339+
path_manager=path_manager,
340+
)
341+
return True
342+
343+
289344
def _parse_face(
290345
line,
291346
tokens,

pytorch3d/io/pluggable.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010
from iopath.common.file_io import PathManager
1111
from pytorch3d.structures import Meshes, Pointclouds
1212

13+
from .obj_io import MeshObjFormat
1314
from .pluggable_formats import MeshFormatInterpreter, PointcloudFormatInterpreter
15+
from .ply_io import MeshPlyFormat
1416

1517

1618
"""
@@ -70,8 +72,8 @@ def __init__(
7072
self.register_default_formats()
7173

7274
def register_default_formats(self) -> None:
73-
# This will be populated in later diffs
74-
pass
75+
self.register_meshes_format(MeshObjFormat())
76+
self.register_meshes_format(MeshPlyFormat())
7577

7678
def register_meshes_format(self, interpreter: MeshFormatInterpreter) -> None:
7779
"""

pytorch3d/io/ply_io.py

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,16 @@
99
import warnings
1010
from collections import namedtuple
1111
from io import BytesIO
12-
from typing import Optional, Tuple
12+
from pathlib import Path
13+
from typing import Optional, Tuple, Union
1314

1415
import numpy as np
1516
import torch
1617
from iopath.common.file_io import PathManager
1718
from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file
19+
from pytorch3d.structures import Meshes
20+
21+
from .pluggable_formats import MeshFormatInterpreter, endswith
1822

1923

2024
_PlyTypeData = namedtuple("_PlyTypeData", "size struct_char np_type")
@@ -679,8 +683,7 @@ def load_ply(f, path_manager: Optional[PathManager] = None):
679683
# but we don't need to enforce this.
680684

681685
if not len(face):
682-
# pyre-fixme[28]: Unexpected keyword argument `size`.
683-
faces = torch.zeros(size=(0, 3), dtype=torch.int64)
686+
faces = torch.zeros((0, 3), dtype=torch.int64)
684687
elif isinstance(face, np.ndarray) and face.ndim == 2: # Homogeneous elements
685688
if face.shape[1] < 3:
686689
raise ValueError("Faces must have at least 3 vertices.")
@@ -831,3 +834,48 @@ def save_ply(
831834
path_manager = PathManager()
832835
with _open_file(f, path_manager, "wb") as f:
833836
_save_ply(f, verts, faces, verts_normals, ascii, decimal_places)
837+
838+
839+
class MeshPlyFormat(MeshFormatInterpreter):
840+
def __init__(self):
841+
self.known_suffixes = (".ply",)
842+
843+
def read(
844+
self,
845+
path: Union[str, Path],
846+
include_textures: bool,
847+
device,
848+
path_manager: PathManager,
849+
**kwargs,
850+
) -> Optional[Meshes]:
851+
if not endswith(path, self.known_suffixes):
852+
return None
853+
854+
verts, faces = load_ply(f=path, path_manager=path_manager)
855+
mesh = Meshes(verts=[verts.to(device)], faces=[faces.to(device)])
856+
return mesh
857+
858+
def save(
859+
self,
860+
data: Meshes,
861+
path: Union[str, Path],
862+
path_manager: PathManager,
863+
binary: Optional[bool],
864+
decimal_places: Optional[int] = None,
865+
**kwargs,
866+
) -> bool:
867+
if not endswith(path, self.known_suffixes):
868+
return False
869+
870+
# TODO: normals are not saved. We only want to save them if they already exist.
871+
verts = data.verts_list()[0]
872+
faces = data.faces_list()[0]
873+
save_ply(
874+
f=path,
875+
verts=verts,
876+
faces=faces,
877+
ascii=binary is False,
878+
decimal_places=decimal_places,
879+
path_manager=path_manager,
880+
)
881+
return True

tests/test_obj_io.py

Lines changed: 68 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@
55
import warnings
66
from io import StringIO
77
from pathlib import Path
8+
from tempfile import NamedTemporaryFile
89

910
import torch
1011
from common_testing import TestCaseMixin
1112
from iopath.common.file_io import PathManager
12-
from pytorch3d.io import load_obj, load_objs_as_meshes, save_obj
13+
from pytorch3d.io import IO, load_obj, load_objs_as_meshes, save_obj
1314
from pytorch3d.io.mtl_io import (
1415
_bilinear_interpolation_grid_sample,
1516
_bilinear_interpolation_vectorized,
@@ -145,6 +146,70 @@ def test_load_obj_complex(self):
145146
self.assertTrue(materials is None)
146147
self.assertTrue(tex_maps is None)
147148

149+
def test_load_obj_complex_pluggable(self):
150+
"""
151+
This won't work on Windows due to the behavior of NamedTemporaryFile
152+
"""
153+
obj_file = "\n".join(
154+
[
155+
"# this is a comment", # Comments should be ignored.
156+
"v 0.1 0.2 0.3",
157+
"v 0.2 0.3 0.4",
158+
"v 0.3 0.4 0.5",
159+
"v 0.4 0.5 0.6",
160+
"vn 0.000000 0.000000 -1.000000",
161+
"vn -1.000000 -0.000000 -0.000000",
162+
"vn -0.000000 -0.000000 1.000000", # Normals should not be ignored.
163+
"v 0.5 0.6 0.7",
164+
"vt 0.749279 0.501284 0.0", # Some files add 0.0 - ignore this.
165+
"vt 0.999110 0.501077",
166+
"vt 0.999455 0.750380",
167+
"f 1 2 3",
168+
"f 1 2 4 3 5", # Polygons should be split into triangles
169+
"f 2/1/2 3/1/2 4/2/2", # Texture/normals are loaded correctly.
170+
"f -1 -2 1", # Negative indexing counts from the end.
171+
]
172+
)
173+
io = IO()
174+
with NamedTemporaryFile(mode="w", suffix=".obj") as f:
175+
f.write(obj_file)
176+
f.flush()
177+
mesh = io.load_mesh(f.name)
178+
mesh_from_path = io.load_mesh(Path(f.name))
179+
180+
with NamedTemporaryFile(mode="w", suffix=".ply") as f:
181+
f.write(obj_file)
182+
f.flush()
183+
with self.assertRaisesRegex(ValueError, "Invalid file header."):
184+
io.load_mesh(f.name)
185+
186+
expected_verts = torch.tensor(
187+
[
188+
[0.1, 0.2, 0.3],
189+
[0.2, 0.3, 0.4],
190+
[0.3, 0.4, 0.5],
191+
[0.4, 0.5, 0.6],
192+
[0.5, 0.6, 0.7],
193+
],
194+
dtype=torch.float32,
195+
)
196+
expected_faces = torch.tensor(
197+
[
198+
[0, 1, 2], # First face
199+
[0, 1, 3], # Second face (polygon)
200+
[0, 3, 2], # Second face (polygon)
201+
[0, 2, 4], # Second face (polygon)
202+
[1, 2, 3], # Third face (normals / texture)
203+
[4, 3, 0], # Fourth face (negative indices)
204+
],
205+
dtype=torch.int64,
206+
)
207+
self.assertClose(mesh.verts_padded(), expected_verts[None])
208+
self.assertClose(mesh.faces_padded(), expected_faces[None])
209+
self.assertClose(mesh_from_path.verts_padded(), expected_verts[None])
210+
self.assertClose(mesh_from_path.faces_padded(), expected_faces[None])
211+
self.assertIsNone(mesh.textures)
212+
148213
def test_load_obj_normals_only(self):
149214
obj_file = "\n".join(
150215
[
@@ -588,8 +653,8 @@ def test_load_obj_mlt_no_image(self):
588653
expected_atlas = torch.tensor([0.5, 0.0, 0.0], dtype=torch.float32)
589654
expected_atlas = expected_atlas[None, None, None, :].expand(2, R, R, -1)
590655
self.assertTrue(torch.allclose(aux.texture_atlas, expected_atlas))
591-
self.assertEquals(len(aux.material_colors.keys()), 1)
592-
self.assertEquals(list(aux.material_colors.keys()), ["material_1"])
656+
self.assertEqual(len(aux.material_colors.keys()), 1)
657+
self.assertEqual(list(aux.material_colors.keys()), ["material_1"])
593658

594659
def test_load_obj_missing_texture(self):
595660
DATA_DIR = Path(__file__).resolve().parent / "data"

0 commit comments

Comments
 (0)