15
15
import numpy as np
16
16
import torch
17
17
from iopath .common .file_io import PathManager
18
+ from PIL import Image
18
19
from pytorch3d .common .types import Device
19
20
from pytorch3d .io .mtl_io import load_mtl , make_mesh_texture_atlas
20
21
from pytorch3d .io .utils import _check_faces_indices , _make_tensor , _open_file
@@ -649,42 +650,118 @@ def _load_obj(
649
650
650
651
651
652
def save_obj (
652
- f ,
653
+ f : Union [ str , os . PathLike ] ,
653
654
verts ,
654
655
faces ,
655
656
decimal_places : Optional [int ] = None ,
656
657
path_manager : Optional [PathManager ] = None ,
657
- ):
658
+ * ,
659
+ verts_uvs : Optional [torch .Tensor ] = None ,
660
+ faces_uvs : Optional [torch .Tensor ] = None ,
661
+ texture_map : Optional [torch .Tensor ] = None ,
662
+ ) -> None :
658
663
"""
659
664
Save a mesh to an .obj file.
660
665
661
666
Args:
662
- f: File (or path) to which the mesh should be written.
667
+ f: File (str or path) to which the mesh should be written.
663
668
verts: FloatTensor of shape (V, 3) giving vertex coordinates.
664
669
faces: LongTensor of shape (F, 3) giving faces.
665
670
decimal_places: Number of decimal places for saving.
666
671
path_manager: Optional PathManager for interpreting f if
667
672
it is a str.
673
+ verts_uvs: FloatTensor of shape (V, 2) giving the uv coordinate per vertex.
674
+ faces_uvs: LongTensor of shape (F, 3) giving the index into verts_uvs for
675
+ each vertex in the face.
676
+ texture_map: FloatTensor of shape (H, W, 3) representing the texture map
677
+ for the mesh which will be saved as an image. The values are expected
678
+ to be in the range [0, 1],
668
679
"""
669
- if len (verts ) and not (verts .dim () == 2 and verts .size (1 ) == 3 ):
670
- message = "Argument 'verts' should either be empty or of shape (num_verts, 3)."
680
+ if len (verts ) and (verts .dim () != 2 or verts .size (1 ) != 3 ):
681
+ message = "'verts' should either be empty or of shape (num_verts, 3)."
682
+ raise ValueError (message )
683
+
684
+ if len (faces ) and (faces .dim () != 2 or faces .size (1 ) != 3 ):
685
+ message = "'faces' should either be empty or of shape (num_faces, 3)."
686
+ raise ValueError (message )
687
+
688
+ if faces_uvs is not None and (faces_uvs .dim () != 2 or faces_uvs .size (1 ) != 3 ):
689
+ message = "'faces_uvs' should either be empty or of shape (num_faces, 3)."
671
690
raise ValueError (message )
672
691
673
- if len (faces ) and not (faces .dim () == 2 and faces .size (1 ) == 3 ):
674
- message = "Argument 'faces' should either be empty or of shape (num_faces, 3)."
692
+ if verts_uvs is not None and (verts_uvs .dim () != 2 or verts_uvs .size (1 ) != 2 ):
693
+ message = "'verts_uvs' should either be empty or of shape (num_verts, 2)."
694
+ raise ValueError (message )
695
+
696
+ if texture_map is not None and (texture_map .dim () != 3 or texture_map .size (2 ) != 3 ):
697
+ message = "'texture_map' should either be empty or of shape (H, W, 3)."
675
698
raise ValueError (message )
676
699
677
700
if path_manager is None :
678
701
path_manager = PathManager ()
679
702
703
+ save_texture = all ([t is not None for t in [faces_uvs , verts_uvs , texture_map ]])
704
+ output_path = Path (f )
705
+
706
+ # Save the .obj file
680
707
with _open_file (f , path_manager , "w" ) as f :
681
- return _save (f , verts , faces , decimal_places )
708
+ if save_texture :
709
+ # Add the header required for the texture info to be loaded correctly
710
+ obj_header = "\n mtllib {0}.mtl\n usemtl mesh\n \n " .format (output_path .stem )
711
+ f .write (obj_header )
712
+ _save (
713
+ f ,
714
+ verts ,
715
+ faces ,
716
+ decimal_places ,
717
+ verts_uvs = verts_uvs ,
718
+ faces_uvs = faces_uvs ,
719
+ save_texture = save_texture ,
720
+ )
721
+
722
+ # Save the .mtl and .png files associated with the texture
723
+ if save_texture :
724
+ image_path = output_path .with_suffix (".png" )
725
+ mtl_path = output_path .with_suffix (".mtl" )
726
+ if isinstance (f , str ):
727
+ # Back to str for iopath interpretation.
728
+ image_path = str (image_path )
729
+ mtl_path = str (mtl_path )
730
+
731
+ # Save texture map to output folder
732
+ # pyre-fixme[16] # undefined attribute cpu
733
+ texture_map = texture_map .detach ().cpu () * 255.0
734
+ image = Image .fromarray (texture_map .numpy ().astype (np .uint8 ))
735
+ with _open_file (image_path , path_manager , "wb" ) as im_f :
736
+ # pyre-fixme[6] # incompatible parameter type
737
+ image .save (im_f )
738
+
739
+ # Create .mtl file with the material name and texture map filename
740
+ # TODO: enable material properties to also be saved.
741
+ with _open_file (mtl_path , path_manager , "w" ) as f_mtl :
742
+ lines = f"newmtl mesh\n " f"map_Kd { output_path .stem } .png\n "
743
+ f_mtl .write (lines )
682
744
683
745
684
746
# TODO (nikhilar) Speed up this function.
685
- def _save (f , verts , faces , decimal_places : Optional [int ] = None ) -> None :
686
- assert not len (verts ) or (verts .dim () == 2 and verts .size (1 ) == 3 )
687
- assert not len (faces ) or (faces .dim () == 2 and faces .size (1 ) == 3 )
747
+ def _save (
748
+ f ,
749
+ verts ,
750
+ faces ,
751
+ decimal_places : Optional [int ] = None ,
752
+ * ,
753
+ verts_uvs : Optional [torch .Tensor ] = None ,
754
+ faces_uvs : Optional [torch .Tensor ] = None ,
755
+ save_texture : bool = False ,
756
+ ) -> None :
757
+
758
+ if len (verts ) and (verts .dim () != 2 or verts .size (1 ) != 3 ):
759
+ message = "'verts' should either be empty or of shape (num_verts, 3)."
760
+ raise ValueError (message )
761
+
762
+ if len (faces ) and (faces .dim () != 2 or faces .size (1 ) != 3 ):
763
+ message = "'faces' should either be empty or of shape (num_faces, 3)."
764
+ raise ValueError (message )
688
765
689
766
if not (len (verts ) or len (faces )):
690
767
warnings .warn ("Empty 'verts' and 'faces' arguments provided" )
@@ -705,15 +782,42 @@ def _save(f, verts, faces, decimal_places: Optional[int] = None) -> None:
705
782
vert = [float_str % verts [i , j ] for j in range (D )]
706
783
lines += "v %s\n " % " " .join (vert )
707
784
785
+ if save_texture :
786
+ if faces_uvs is not None and (faces_uvs .dim () != 2 or faces_uvs .size (1 ) != 3 ):
787
+ message = "'faces_uvs' should either be empty or of shape (num_faces, 3)."
788
+ raise ValueError (message )
789
+
790
+ if verts_uvs is not None and (verts_uvs .dim () != 2 or verts_uvs .size (1 ) != 2 ):
791
+ message = "'verts_uvs' should either be empty or of shape (num_verts, 2)."
792
+ raise ValueError (message )
793
+
794
+ # pyre-fixme[16] # undefined attribute cpu
795
+ verts_uvs , faces_uvs = verts_uvs .cpu (), faces_uvs .cpu ()
796
+
797
+ # Save verts uvs after verts
798
+ if len (verts_uvs ):
799
+ uV , uD = verts_uvs .shape
800
+ for i in range (uV ):
801
+ uv = [float_str % verts_uvs [i , j ] for j in range (uD )]
802
+ lines += "vt %s\n " % " " .join (uv )
803
+
708
804
if torch .any (faces >= verts .shape [0 ]) or torch .any (faces < 0 ):
709
805
warnings .warn ("Faces have invalid indices" )
710
806
711
807
if len (faces ):
712
808
F , P = faces .shape
713
809
for i in range (F ):
714
- face = ["%d" % (faces [i , j ] + 1 ) for j in range (P )]
810
+ if save_texture :
811
+ # Format faces as {verts_idx}/{verts_uvs_idx}
812
+ face = [
813
+ "%d/%d" % (faces [i , j ] + 1 , faces_uvs [i , j ] + 1 ) for j in range (P )
814
+ ]
815
+ else :
816
+ face = ["%d" % (faces [i , j ] + 1 ) for j in range (P )]
817
+
715
818
if i + 1 < F :
716
819
lines += "f %s\n " % " " .join (face )
820
+
717
821
elif i + 1 == F :
718
822
# No newline at the end of the file.
719
823
lines += "f %s" % " " .join (face )
0 commit comments