Skip to content

Commit b183dcb

Browse files
bottlerfacebook-github-bot
authored andcommitted
skeleton of pluggable IO
Summary: Unified interface for loading and saving meshes and pointclouds. Reviewed By: nikhilaravi Differential Revision: D25372968 fbshipit-source-id: 6fe57cc3704a89d81d13e959bee707b0c7b57d3b
1 parent 9fc661f commit b183dcb

File tree

5 files changed

+370
-1
lines changed

5 files changed

+370
-1
lines changed

docs/notes/io.md

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
---
2+
hide_title: true
3+
sidebar_label: File IO
4+
---
5+
6+
# File IO
7+
There is a flexible interface for loading and saving point clouds and meshes from different formats.
8+
9+
The main usage is via the `pytorch3d.io.IO` object, and its methods
10+
`load_mesh`, `save_mesh`, `load_point_cloud` and `save_point_cloud`.
11+
12+
For example, to load a mesh you might do
13+
```
14+
from pytorch3d.io import IO
15+
16+
device=torch.device("cuda:0")
17+
mesh = IO().load_mesh("mymesh.ply", device=device)
18+
```
19+
20+
and to save a pointcloud you might do
21+
```
22+
pcl = Pointclouds(...)
23+
IO().save_point_cloud(pcl, "output_poincloud.obj")
24+
```

pytorch3d/io/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33

44
from .obj_io import load_obj, load_objs_as_meshes, save_obj
5+
from .pluggable import IO
56
from .ply_io import load_ply, save_ply
67

78

pytorch3d/io/pluggable.py

Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2+
# This source code is licensed under the license found in the
3+
# LICENSE file in the root directory of this source tree.
4+
5+
6+
from collections import deque
7+
from pathlib import Path
8+
from typing import Deque, Optional, Union
9+
10+
from iopath.common.file_io import PathManager
11+
from pytorch3d.structures import Meshes, Pointclouds
12+
13+
from .pluggable_formats import MeshFormatInterpreter, PointcloudFormatInterpreter
14+
15+
16+
"""
17+
This module has the master functions for loading and saving data.
18+
19+
The main usage is via the IO object, and its methods
20+
`load_mesh`, `save_mesh`, `load_pointcloud` and `save_pointcloud`.
21+
22+
For example, to load a mesh you might do
23+
```
24+
from pytorch3d.io import IO
25+
26+
mesh = IO().load_mesh("mymesh.obj")
27+
```
28+
29+
and to save a point cloud you might do
30+
31+
```
32+
pcl = Pointclouds(...)
33+
IO().save_pointcloud(pcl, "output_poincloud.obj")
34+
```
35+
36+
"""
37+
38+
39+
class IO:
40+
"""
41+
This class is the interface to flexible loading and saving of meshes and point clouds.
42+
43+
In simple cases the user will just initialise an instance of this class as `IO()`
44+
and then use its load and save functions. The arguments of the initializer are not
45+
usually needed.
46+
47+
The user can add their own formats for saving and loading by passing their own objects
48+
to the register_* functions.
49+
50+
Args:
51+
include_default_formats: If False, the built-in file formats will not be available.
52+
Then only user-registered formats can be used.
53+
path_manager: Used to customise how paths given as strings are interpreted.
54+
"""
55+
56+
def __init__(
57+
self,
58+
include_default_formats: bool = True,
59+
path_manager: Optional[PathManager] = None,
60+
):
61+
if path_manager is None:
62+
self.path_manager = PathManager()
63+
else:
64+
self.path_manager = path_manager
65+
66+
self.mesh_interpreters: Deque[MeshFormatInterpreter] = deque()
67+
self.pointcloud_interpreters: Deque[PointcloudFormatInterpreter] = deque()
68+
69+
if include_default_formats:
70+
self.register_default_formats()
71+
72+
def register_default_formats(self) -> None:
73+
# This will be populated in later diffs
74+
pass
75+
76+
def register_meshes_format(self, interpreter: MeshFormatInterpreter) -> None:
77+
"""
78+
Register a new interpreter for a new mesh file format.
79+
80+
Args:
81+
interpreter: the new interpreter to use, which must be an instance
82+
of a class which inherits MeshFormatInterpreter.
83+
"""
84+
self.mesh_interpreters.appendleft(interpreter)
85+
86+
def register_pointcloud_format(
87+
self, interpreter: PointcloudFormatInterpreter
88+
) -> None:
89+
"""
90+
Register a new interpreter for a new point cloud file format.
91+
92+
Args:
93+
interpreter: the new interpreter to use, which must be an instance
94+
of a class which inherits PointcloudFormatInterpreter.
95+
"""
96+
self.pointcloud_interpreters.appendleft(interpreter)
97+
98+
def load_mesh(
99+
self,
100+
path: Union[str, Path],
101+
include_textures: bool = True,
102+
device="cpu",
103+
**kwargs,
104+
) -> Meshes:
105+
"""
106+
Attempt to load a mesh from the given file, using a registered format.
107+
Materials are not returned. If you have a .obj file with materials
108+
you might want to load them with the load_obj function instead.
109+
110+
Args:
111+
path: file to read
112+
include_textures: whether to try to load texture information
113+
device: device on which to leave the data.
114+
115+
Returns:
116+
new Meshes object containing one mesh.
117+
"""
118+
for mesh_interpreter in self.mesh_interpreters:
119+
mesh = mesh_interpreter.read(
120+
path,
121+
include_textures=include_textures,
122+
path_manager=self.path_manager,
123+
device=device,
124+
**kwargs,
125+
)
126+
if mesh is not None:
127+
return mesh
128+
129+
raise ValueError(f"No mesh interpreter found to read {path}.")
130+
131+
def save_mesh(
132+
self,
133+
data: Meshes,
134+
path: Union[str, Path],
135+
binary: Optional[bool] = None,
136+
include_textures: bool = True,
137+
**kwargs,
138+
) -> None:
139+
"""
140+
Attempt to save a mesh to the given file, using a registered format.
141+
142+
Args:
143+
data: a 1-element Meshes
144+
path: file to write
145+
binary: If there is a choice, whether to save in a binary format.
146+
include_textures: If textures are present, whether to try to save
147+
them.
148+
"""
149+
if len(data) != 1:
150+
raise ValueError("Can only save a single mesh.")
151+
152+
for mesh_interpreter in self.mesh_interpreters:
153+
success = mesh_interpreter.save(
154+
data, path, path_manager=self.path_manager, binary=binary, **kwargs
155+
)
156+
if success:
157+
return
158+
159+
raise ValueError(f"No mesh interpreter found to write to {path}.")
160+
161+
def load_pointcloud(
162+
self, path: Union[str, Path], device="cpu", **kwargs
163+
) -> Pointclouds:
164+
"""
165+
Attempt to load a point cloud from the given file, using a registered format.
166+
167+
Args:
168+
path: file to read
169+
device: torch.device on which to load the data.
170+
171+
Returns:
172+
new Pointclouds object containing one mesh.
173+
"""
174+
for pointcloud_interpreter in self.pointcloud_interpreters:
175+
pointcloud = pointcloud_interpreter.read(
176+
path, path_manager=self.path_manager, device=device, **kwargs
177+
)
178+
if pointcloud is not None:
179+
return pointcloud
180+
181+
raise ValueError(f"No point cloud interpreter found to read {path}.")
182+
183+
def save_pointcloud(
184+
self,
185+
data: Pointclouds,
186+
path: Union[str, Path],
187+
binary: Optional[bool] = None,
188+
**kwargs,
189+
) -> None:
190+
"""
191+
Attempt to save a point cloud to the given file, using a registered format.
192+
193+
Args:
194+
data: a 1-element Pointclouds
195+
path: file to write
196+
binary: If there is a choice, whether to save in a binary format.
197+
"""
198+
if len(data) != 1:
199+
raise ValueError("Can only save a single point cloud.")
200+
201+
for pointcloud_interpreter in self.pointcloud_interpreters:
202+
success = pointcloud_interpreter.save(
203+
data, path, path_manager=self.path_manager, binary=binary, **kwargs
204+
)
205+
if success:
206+
return
207+
208+
raise ValueError(f"No point cloud interpreter found to write to {path}.")

pytorch3d/io/pluggable_formats.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2+
# This source code is licensed under the license found in the
3+
# LICENSE file in the root directory of this source tree.
4+
5+
6+
from pathlib import Path
7+
from typing import Optional, Tuple, Union
8+
9+
from iopath.common.file_io import PathManager
10+
from pytorch3d.structures import Meshes, Pointclouds
11+
12+
13+
"""
14+
This module has the base classes which must be extended to define
15+
an interpreter for loading and saving data in a particular format.
16+
These can be registered on an IO object so that they can be used in
17+
its load_* and save_* functions.
18+
"""
19+
20+
21+
def endswith(path, suffixes: Tuple[str, ...]) -> bool:
22+
"""
23+
Returns whether the path ends with one of the given suffixes.
24+
If `path` is not actually a path, returns True. This is useful
25+
for allowing interpreters to bypass inappropriate paths, but
26+
always accepting streams.
27+
"""
28+
if isinstance(path, Path):
29+
return path.suffix.lower() in suffixes
30+
if isinstance(path, str):
31+
return path.lower().endswith(suffixes)
32+
return True
33+
34+
35+
class MeshFormatInterpreter:
36+
"""
37+
This is a base class for an interpreter which can read or write
38+
a mesh in a particular format.
39+
"""
40+
41+
def read(
42+
self,
43+
path: Union[str, Path],
44+
include_textures: bool,
45+
device,
46+
path_manager: PathManager,
47+
**kwargs,
48+
) -> Optional[Meshes]:
49+
"""
50+
Read the data from the specified file and return it as
51+
a Meshes object.
52+
53+
Args:
54+
path: path to load.
55+
include_textures: whether to try to load texture information.
56+
device: torch.device to load data on to.
57+
path_manager: PathManager to interpret the path.
58+
59+
Returns:
60+
None if self is not the appropriate object to interpret the given
61+
path.
62+
Otherwise, the read Meshes object.
63+
"""
64+
raise NotImplementedError()
65+
66+
def save(
67+
self,
68+
data: Meshes,
69+
path: Union[str, Path],
70+
path_manager: PathManager,
71+
binary: Optional[bool],
72+
**kwargs,
73+
) -> bool:
74+
"""
75+
Save the given Meshes object to the given path.
76+
77+
Args:
78+
data: mesh to save
79+
path: path to save to, which may be overwritten.
80+
path_manager: PathManager to interpret the path.
81+
binary: If there is a choice, whether to save in a binary format.
82+
83+
Returns:
84+
False: if self is not the appropriate object to write to the given path.
85+
True: on success.
86+
"""
87+
raise NotImplementedError()
88+
89+
90+
class PointcloudFormatInterpreter:
91+
"""
92+
This is a base class for an interpreter which can read or write
93+
a point cloud in a particular format.
94+
"""
95+
96+
def read(
97+
self, path: Union[str, Path], device, path_manager: PathManager, **kwargs
98+
) -> Optional[Pointclouds]:
99+
"""
100+
Read the data from the specified file and return it as
101+
a Pointclouds object.
102+
103+
Args:
104+
path: path to load.
105+
device: torch.device to load data on to.
106+
path_manager: PathManager to interpret the path.
107+
108+
Returns:
109+
None if self is not the appropriate object to interpret the given
110+
path.
111+
Otherwise, the read Pointclouds object.
112+
"""
113+
raise NotImplementedError()
114+
115+
def save(
116+
self,
117+
data: Pointclouds,
118+
path: Union[str, Path],
119+
path_manager: PathManager,
120+
binary: Optional[bool],
121+
**kwargs,
122+
) -> bool:
123+
"""
124+
Save the given Pointclouds object to the given path.
125+
126+
Args:
127+
data: point cloud object to save
128+
path: path to save to, which may be overwritten.
129+
path_manager: PathManager to interpret the path.
130+
binary: If there is a choice, whether to save in a binary format.
131+
132+
Returns:
133+
False: if self is not the appropriate object to write to the given path.
134+
True: on success.
135+
"""
136+
raise NotImplementedError()

pytorch3d/structures/pointclouds.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -777,7 +777,7 @@ def split(self, split_sizes: list):
777777
returned.
778778
779779
Returns:
780-
list[PointClouds].
780+
list[Pointclouds].
781781
"""
782782
if not all(isinstance(x, int) for x in split_sizes):
783783
raise ValueError("Value of split_sizes must be a list of integers.")

0 commit comments

Comments
 (0)