Skip to content

Commit 542e2e7

Browse files
nikhilaravifacebook-github-bot
authored andcommitted
Save UV texture with obj mesh
Summary: Add functionality to to save an `.obj` file with associated UV textures: `.png` image and `.mtl` file as well as saving verts_uvs and faces_uvs to the `.obj` file. Reviewed By: bottler Differential Revision: D29337562 fbshipit-source-id: 86829b40dae9224088b328e7f5a16eacf8582eb5
1 parent 64289a4 commit 542e2e7

File tree

3 files changed

+549
-233
lines changed

3 files changed

+549
-233
lines changed

pytorch3d/io/obj_io.py

Lines changed: 116 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import numpy as np
1616
import torch
1717
from iopath.common.file_io import PathManager
18+
from PIL import Image
1819
from pytorch3d.common.types import Device
1920
from pytorch3d.io.mtl_io import load_mtl, make_mesh_texture_atlas
2021
from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file
@@ -649,42 +650,118 @@ def _load_obj(
649650

650651

651652
def save_obj(
652-
f,
653+
f: Union[str, os.PathLike],
653654
verts,
654655
faces,
655656
decimal_places: Optional[int] = None,
656657
path_manager: Optional[PathManager] = None,
657-
):
658+
*,
659+
verts_uvs: Optional[torch.Tensor] = None,
660+
faces_uvs: Optional[torch.Tensor] = None,
661+
texture_map: Optional[torch.Tensor] = None,
662+
) -> None:
658663
"""
659664
Save a mesh to an .obj file.
660665
661666
Args:
662-
f: File (or path) to which the mesh should be written.
667+
f: File (str or path) to which the mesh should be written.
663668
verts: FloatTensor of shape (V, 3) giving vertex coordinates.
664669
faces: LongTensor of shape (F, 3) giving faces.
665670
decimal_places: Number of decimal places for saving.
666671
path_manager: Optional PathManager for interpreting f if
667672
it is a str.
673+
verts_uvs: FloatTensor of shape (V, 2) giving the uv coordinate per vertex.
674+
faces_uvs: LongTensor of shape (F, 3) giving the index into verts_uvs for
675+
each vertex in the face.
676+
texture_map: FloatTensor of shape (H, W, 3) representing the texture map
677+
for the mesh which will be saved as an image. The values are expected
678+
to be in the range [0, 1],
668679
"""
669-
if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3):
670-
message = "Argument 'verts' should either be empty or of shape (num_verts, 3)."
680+
if len(verts) and (verts.dim() != 2 or verts.size(1) != 3):
681+
message = "'verts' should either be empty or of shape (num_verts, 3)."
682+
raise ValueError(message)
683+
684+
if len(faces) and (faces.dim() != 2 or faces.size(1) != 3):
685+
message = "'faces' should either be empty or of shape (num_faces, 3)."
686+
raise ValueError(message)
687+
688+
if faces_uvs is not None and (faces_uvs.dim() != 2 or faces_uvs.size(1) != 3):
689+
message = "'faces_uvs' should either be empty or of shape (num_faces, 3)."
671690
raise ValueError(message)
672691

673-
if len(faces) and not (faces.dim() == 2 and faces.size(1) == 3):
674-
message = "Argument 'faces' should either be empty or of shape (num_faces, 3)."
692+
if verts_uvs is not None and (verts_uvs.dim() != 2 or verts_uvs.size(1) != 2):
693+
message = "'verts_uvs' should either be empty or of shape (num_verts, 2)."
694+
raise ValueError(message)
695+
696+
if texture_map is not None and (texture_map.dim() != 3 or texture_map.size(2) != 3):
697+
message = "'texture_map' should either be empty or of shape (H, W, 3)."
675698
raise ValueError(message)
676699

677700
if path_manager is None:
678701
path_manager = PathManager()
679702

703+
save_texture = all([t is not None for t in [faces_uvs, verts_uvs, texture_map]])
704+
output_path = Path(f)
705+
706+
# Save the .obj file
680707
with _open_file(f, path_manager, "w") as f:
681-
return _save(f, verts, faces, decimal_places)
708+
if save_texture:
709+
# Add the header required for the texture info to be loaded correctly
710+
obj_header = "\nmtllib {0}.mtl\nusemtl mesh\n\n".format(output_path.stem)
711+
f.write(obj_header)
712+
_save(
713+
f,
714+
verts,
715+
faces,
716+
decimal_places,
717+
verts_uvs=verts_uvs,
718+
faces_uvs=faces_uvs,
719+
save_texture=save_texture,
720+
)
721+
722+
# Save the .mtl and .png files associated with the texture
723+
if save_texture:
724+
image_path = output_path.with_suffix(".png")
725+
mtl_path = output_path.with_suffix(".mtl")
726+
if isinstance(f, str):
727+
# Back to str for iopath interpretation.
728+
image_path = str(image_path)
729+
mtl_path = str(mtl_path)
730+
731+
# Save texture map to output folder
732+
# pyre-fixme[16] # undefined attribute cpu
733+
texture_map = texture_map.detach().cpu() * 255.0
734+
image = Image.fromarray(texture_map.numpy().astype(np.uint8))
735+
with _open_file(image_path, path_manager, "wb") as im_f:
736+
# pyre-fixme[6] # incompatible parameter type
737+
image.save(im_f)
738+
739+
# Create .mtl file with the material name and texture map filename
740+
# TODO: enable material properties to also be saved.
741+
with _open_file(mtl_path, path_manager, "w") as f_mtl:
742+
lines = f"newmtl mesh\n" f"map_Kd {output_path.stem}.png\n"
743+
f_mtl.write(lines)
682744

683745

684746
# TODO (nikhilar) Speed up this function.
685-
def _save(f, verts, faces, decimal_places: Optional[int] = None) -> None:
686-
assert not len(verts) or (verts.dim() == 2 and verts.size(1) == 3)
687-
assert not len(faces) or (faces.dim() == 2 and faces.size(1) == 3)
747+
def _save(
748+
f,
749+
verts,
750+
faces,
751+
decimal_places: Optional[int] = None,
752+
*,
753+
verts_uvs: Optional[torch.Tensor] = None,
754+
faces_uvs: Optional[torch.Tensor] = None,
755+
save_texture: bool = False,
756+
) -> None:
757+
758+
if len(verts) and (verts.dim() != 2 or verts.size(1) != 3):
759+
message = "'verts' should either be empty or of shape (num_verts, 3)."
760+
raise ValueError(message)
761+
762+
if len(faces) and (faces.dim() != 2 or faces.size(1) != 3):
763+
message = "'faces' should either be empty or of shape (num_faces, 3)."
764+
raise ValueError(message)
688765

689766
if not (len(verts) or len(faces)):
690767
warnings.warn("Empty 'verts' and 'faces' arguments provided")
@@ -705,15 +782,42 @@ def _save(f, verts, faces, decimal_places: Optional[int] = None) -> None:
705782
vert = [float_str % verts[i, j] for j in range(D)]
706783
lines += "v %s\n" % " ".join(vert)
707784

785+
if save_texture:
786+
if faces_uvs is not None and (faces_uvs.dim() != 2 or faces_uvs.size(1) != 3):
787+
message = "'faces_uvs' should either be empty or of shape (num_faces, 3)."
788+
raise ValueError(message)
789+
790+
if verts_uvs is not None and (verts_uvs.dim() != 2 or verts_uvs.size(1) != 2):
791+
message = "'verts_uvs' should either be empty or of shape (num_verts, 2)."
792+
raise ValueError(message)
793+
794+
# pyre-fixme[16] # undefined attribute cpu
795+
verts_uvs, faces_uvs = verts_uvs.cpu(), faces_uvs.cpu()
796+
797+
# Save verts uvs after verts
798+
if len(verts_uvs):
799+
uV, uD = verts_uvs.shape
800+
for i in range(uV):
801+
uv = [float_str % verts_uvs[i, j] for j in range(uD)]
802+
lines += "vt %s\n" % " ".join(uv)
803+
708804
if torch.any(faces >= verts.shape[0]) or torch.any(faces < 0):
709805
warnings.warn("Faces have invalid indices")
710806

711807
if len(faces):
712808
F, P = faces.shape
713809
for i in range(F):
714-
face = ["%d" % (faces[i, j] + 1) for j in range(P)]
810+
if save_texture:
811+
# Format faces as {verts_idx}/{verts_uvs_idx}
812+
face = [
813+
"%d/%d" % (faces[i, j] + 1, faces_uvs[i, j] + 1) for j in range(P)
814+
]
815+
else:
816+
face = ["%d" % (faces[i, j] + 1) for j in range(P)]
817+
715818
if i + 1 < F:
716819
lines += "f %s\n" % " ".join(face)
820+
717821
elif i + 1 == F:
718822
# No newline at the end of the file.
719823
lines += "f %s" % " ".join(face)

tests/common_testing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def get_pytorch3d_dir() -> Path:
3434

3535

3636
def load_rgb_image(filename: str, data_dir: Union[str, Path]):
37-
filepath = data_dir / filename
37+
filepath = os.path.join(data_dir, filename)
3838
with Image.open(filepath) as raw_image:
3939
image = torch.from_numpy(np.array(raw_image) / 255.0)
4040
image = image.to(dtype=torch.float32)

0 commit comments

Comments
 (0)