@@ -648,11 +648,12 @@ def _load_obj(
648
648
649
649
650
650
def save_obj (
651
- f ,
652
- verts ,
653
- faces ,
654
- decimal_places : Optional [int ] = None ,
655
- path_manager : Optional [PathManager ] = None ,
651
+ f ,
652
+ verts ,
653
+ faces ,
654
+ normals : Optional [torch .FloatTensor ] = None ,
655
+ decimal_places : Optional [int ] = None ,
656
+ path_manager : Optional [PathManager ] = None ,
656
657
):
657
658
"""
658
659
Save a mesh to an .obj file.
@@ -661,6 +662,7 @@ def save_obj(
661
662
f: File (or path) to which the mesh should be written.
662
663
verts: FloatTensor of shape (V, 3) giving vertex coordinates.
663
664
faces: LongTensor of shape (F, 3) giving faces.
665
+ normals: FloatTensor of shape (N, 3) giving normals.
664
666
decimal_places: Number of decimal places for saving.
665
667
path_manager: Optional PathManager for interpreting f if
666
668
it is a str.
@@ -677,44 +679,50 @@ def save_obj(
677
679
path_manager = PathManager ()
678
680
679
681
with _open_file (f , path_manager , "w" ) as f :
680
- return _save (f , verts , faces , decimal_places )
682
+ return _save (f , verts , faces , normals , decimal_places )
681
683
682
684
683
- # TODO (nikhilar) Speed up this function.
684
- def _save (f , verts , faces , decimal_places : Optional [int ] = None ) -> None :
685
+ def _save (f , verts , faces , normals : Optional [torch .FloatTensor ] = None , decimal_places : Optional [int ] = None ) -> None :
685
686
assert not len (verts ) or (verts .dim () == 2 and verts .size (1 ) == 3 )
686
687
assert not len (faces ) or (faces .dim () == 2 and faces .size (1 ) == 3 )
688
+ if normals is not None :
689
+ assert not len (normals ) or (normals .dim () == 2 and normals .size (1 ) == 3 )
690
+ assert len (normals ) == len (faces ) or len (normals ) == len (verts )
687
691
688
692
if not (len (verts ) or len (faces )):
689
693
warnings .warn ("Empty 'verts' and 'faces' arguments provided" )
690
694
return
691
695
692
- verts , faces = verts .cpu (), faces .cpu ()
696
+ if torch .any (faces >= verts .shape [0 ]) or torch .any (faces < 0 ):
697
+ warnings .warn ("Faces have invalid indices" )
698
+
699
+ if decimal_places is None :
700
+ float_format = "{:f}" .format
701
+ else :
702
+ float_format = ("{:.%df}" % decimal_places ).format
703
+
704
+ float_format = np .vectorize (float_format )
705
+ verts , faces = verts .cpu ().numpy (), faces .cpu ().numpy ()
693
706
694
707
lines = ""
695
708
696
709
if len (verts ):
697
- if decimal_places is None :
698
- float_str = "%f"
699
- else :
700
- float_str = "%" + ".%df" % decimal_places
710
+ rows = np .apply_along_axis (' ' .join , 1 , float_format (verts ))
711
+ lines += 'v ' + '\n v ' .join (rows )
701
712
702
- V , D = verts .shape
703
- for i in range (V ):
704
- vert = [float_str % verts [i , j ] for j in range (D )]
705
- lines += "v %s\n " % " " .join (vert )
706
-
707
- if torch .any (faces >= verts .shape [0 ]) or torch .any (faces < 0 ):
708
- warnings .warn ("Faces have invalid indices" )
713
+ if normals is not None and len (normals ):
714
+ normals = normals .cpu ().numpy ()
715
+ rows = np .apply_along_axis (' ' .join , 1 , float_format (normals ))
716
+ lines += '\n vn ' + '\n vn ' .join (rows )
709
717
710
718
if len (faces ):
711
- F , P = faces . shape
712
- for i in range ( F ):
713
- face = [ "%d" % ( faces [ i , j ] + 1 ) for j in range ( P )]
714
- if i + 1 < F :
715
- lines += "f %s \n " % " " . join ( face )
716
- elif i + 1 == F :
717
- # No newline at the end of the file.
718
- lines += "f %s" % " " . join ( face )
719
-
719
+ # faces start indexing with 1 and not 0
720
+ faces = ( faces + 1 ). astype ( int ). astype ( str )
721
+ if normals is not None :
722
+ normals_format = np . vectorize ( "{0}//{0}" . format )
723
+ faces = normals_format ( faces )
724
+ rows = np . apply_along_axis ( ' ' . join , 1 , faces )
725
+ lines += ' \n f ' + ' \n f ' . join ( rows )
726
+
727
+ lines = lines . strip ()
720
728
f .write (lines )
0 commit comments