@@ -648,11 +648,12 @@ def _load_obj(
648
648
649
649
650
650
def save_obj (
651
- f ,
652
- verts ,
653
- faces ,
654
- decimal_places : Optional [int ] = None ,
655
- path_manager : Optional [PathManager ] = None ,
651
+ f ,
652
+ verts ,
653
+ faces ,
654
+ normals : Optional [torch .FloatTensor ] = None ,
655
+ decimal_places : Optional [int ] = None ,
656
+ path_manager : Optional [PathManager ] = None ,
656
657
):
657
658
"""
658
659
Save a mesh to an .obj file.
@@ -661,6 +662,7 @@ def save_obj(
661
662
f: File (or path) to which the mesh should be written.
662
663
verts: FloatTensor of shape (V, 3) giving vertex coordinates.
663
664
faces: LongTensor of shape (F, 3) giving faces.
665
+ normals: FloatTensor of shape (N, 3) giving normals.
664
666
decimal_places: Number of decimal places for saving.
665
667
path_manager: Optional PathManager for interpreting f if
666
668
it is a str.
@@ -677,44 +679,56 @@ def save_obj(
677
679
path_manager = PathManager ()
678
680
679
681
with _open_file (f , path_manager , "w" ) as f :
680
- return _save (f , verts , faces , decimal_places )
682
+ return _save (f , verts , faces , normals , decimal_places )
681
683
682
684
683
- # TODO (nikhilar) Speed up this function.
684
- def _save (f , verts , faces , decimal_places : Optional [int ] = None ) -> None :
685
+ def _save (
686
+ f ,
687
+ verts ,
688
+ faces ,
689
+ normals : Optional [torch .FloatTensor ] = None ,
690
+ decimal_places : Optional [int ] = None ,
691
+ ) -> None :
685
692
assert not len (verts ) or (verts .dim () == 2 and verts .size (1 ) == 3 )
686
693
assert not len (faces ) or (faces .dim () == 2 and faces .size (1 ) == 3 )
694
+ if normals is not None :
695
+ assert not len (normals ) or (normals .dim () == 2 and normals .size (1 ) == 3 )
696
+ assert len (normals ) == len (faces ) or len (normals ) == len (verts )
687
697
688
698
if not (len (verts ) or len (faces )):
689
699
warnings .warn ("Empty 'verts' and 'faces' arguments provided" )
690
700
return
691
701
692
- verts , faces = verts .cpu (), faces .cpu ()
702
+ if torch .any (faces >= verts .shape [0 ]) or torch .any (faces < 0 ):
703
+ warnings .warn ("Faces have invalid indices" )
704
+
705
+ if decimal_places is None :
706
+ float_format = "{:f}" .format
707
+ else :
708
+ float_format = ("{:.%df}" % decimal_places ).format
709
+
710
+ float_format = np .vectorize (float_format )
711
+ verts , faces = verts .cpu ().numpy (), faces .cpu ().numpy ()
693
712
694
713
lines = ""
695
714
696
715
if len (verts ):
697
- if decimal_places is None :
698
- float_str = "%f"
699
- else :
700
- float_str = "%" + ".%df" % decimal_places
716
+ rows = np .apply_along_axis (" " .join , 1 , float_format (verts ))
717
+ lines += "v " + "\n v " .join (rows )
701
718
702
- V , D = verts .shape
703
- for i in range (V ):
704
- vert = [float_str % verts [i , j ] for j in range (D )]
705
- lines += "v %s\n " % " " .join (vert )
706
-
707
- if torch .any (faces >= verts .shape [0 ]) or torch .any (faces < 0 ):
708
- warnings .warn ("Faces have invalid indices" )
719
+ if normals is not None and len (normals ):
720
+ normals = normals .cpu ().numpy ()
721
+ rows = np .apply_along_axis (" " .join , 1 , float_format (normals ))
722
+ lines += "\n vn " + "\n vn " .join (rows )
709
723
710
724
if len (faces ):
711
- F , P = faces . shape
712
- for i in range ( F ):
713
- face = [ "%d" % ( faces [ i , j ] + 1 ) for j in range ( P )]
714
- if i + 1 < F :
715
- lines += "f %s \n " % " " . join ( face )
716
- elif i + 1 == F :
717
- # No newline at the end of the file.
718
- lines += "f %s" % " " . join ( face )
719
-
725
+ # faces start indexing with 1 and not 0
726
+ faces = ( faces + 1 ). astype ( int ). astype ( str )
727
+ if normals is not None :
728
+ normals_format = np . vectorize ( "{0}//{0}" . format )
729
+ faces = normals_format ( faces )
730
+ rows = np . apply_along_axis ( " " . join , 1 , faces )
731
+ lines += " \n f " + " \n f " . join ( rows )
732
+
733
+ lines = lines . strip ()
720
734
f .write (lines )
0 commit comments