Skip to content

Commit 33c7a79

Browse files
committed
Speed up save_obj function, add option of saving normals.
1 parent d17b121 commit 33c7a79

File tree

2 files changed

+76
-29
lines changed

2 files changed

+76
-29
lines changed

pytorch3d/io/obj_io.py

+43-29
Original file line numberDiff line numberDiff line change
@@ -648,11 +648,12 @@ def _load_obj(
648648

649649

650650
def save_obj(
651-
f,
652-
verts,
653-
faces,
654-
decimal_places: Optional[int] = None,
655-
path_manager: Optional[PathManager] = None,
651+
f,
652+
verts,
653+
faces,
654+
normals: Optional[torch.FloatTensor] = None,
655+
decimal_places: Optional[int] = None,
656+
path_manager: Optional[PathManager] = None,
656657
):
657658
"""
658659
Save a mesh to an .obj file.
@@ -661,6 +662,7 @@ def save_obj(
661662
f: File (or path) to which the mesh should be written.
662663
verts: FloatTensor of shape (V, 3) giving vertex coordinates.
663664
faces: LongTensor of shape (F, 3) giving faces.
665+
normals: FloatTensor of shape (N, 3) giving normals.
664666
decimal_places: Number of decimal places for saving.
665667
path_manager: Optional PathManager for interpreting f if
666668
it is a str.
@@ -677,44 +679,56 @@ def save_obj(
677679
path_manager = PathManager()
678680

679681
with _open_file(f, path_manager, "w") as f:
680-
return _save(f, verts, faces, decimal_places)
682+
return _save(f, verts, faces, normals, decimal_places)
681683

682684

683-
# TODO (nikhilar) Speed up this function.
684-
def _save(f, verts, faces, decimal_places: Optional[int] = None) -> None:
685+
def _save(
686+
f,
687+
verts,
688+
faces,
689+
normals: Optional[torch.FloatTensor] = None,
690+
decimal_places: Optional[int] = None,
691+
) -> None:
685692
assert not len(verts) or (verts.dim() == 2 and verts.size(1) == 3)
686693
assert not len(faces) or (faces.dim() == 2 and faces.size(1) == 3)
694+
if normals is not None:
695+
assert not len(normals) or (normals.dim() == 2 and normals.size(1) == 3)
696+
assert len(normals) == len(faces) or len(normals) == len(verts)
687697

688698
if not (len(verts) or len(faces)):
689699
warnings.warn("Empty 'verts' and 'faces' arguments provided")
690700
return
691701

692-
verts, faces = verts.cpu(), faces.cpu()
702+
if torch.any(faces >= verts.shape[0]) or torch.any(faces < 0):
703+
warnings.warn("Faces have invalid indices")
704+
705+
if decimal_places is None:
706+
float_format = "{:f}".format
707+
else:
708+
float_format = ("{:.%df}" % decimal_places).format
709+
710+
float_format = np.vectorize(float_format)
711+
verts, faces = verts.cpu().numpy(), faces.cpu().numpy()
693712

694713
lines = ""
695714

696715
if len(verts):
697-
if decimal_places is None:
698-
float_str = "%f"
699-
else:
700-
float_str = "%" + ".%df" % decimal_places
716+
rows = np.apply_along_axis(" ".join, 1, float_format(verts))
717+
lines += "v " + "\nv ".join(rows)
701718

702-
V, D = verts.shape
703-
for i in range(V):
704-
vert = [float_str % verts[i, j] for j in range(D)]
705-
lines += "v %s\n" % " ".join(vert)
706-
707-
if torch.any(faces >= verts.shape[0]) or torch.any(faces < 0):
708-
warnings.warn("Faces have invalid indices")
719+
if normals is not None and len(normals):
720+
normals = normals.cpu().numpy()
721+
rows = np.apply_along_axis(" ".join, 1, float_format(normals))
722+
lines += "\nvn " + "\nvn ".join(rows)
709723

710724
if len(faces):
711-
F, P = faces.shape
712-
for i in range(F):
713-
face = ["%d" % (faces[i, j] + 1) for j in range(P)]
714-
if i + 1 < F:
715-
lines += "f %s\n" % " ".join(face)
716-
elif i + 1 == F:
717-
# No newline at the end of the file.
718-
lines += "f %s" % " ".join(face)
719-
725+
# faces start indexing with 1 and not 0
726+
faces = (faces + 1).astype(int).astype(str)
727+
if normals is not None:
728+
normals_format = np.vectorize("{0}//{0}".format)
729+
faces = normals_format(faces)
730+
rows = np.apply_along_axis(" ".join, 1, faces)
731+
lines += "\nf " + "\nf ".join(rows)
732+
733+
lines = lines.strip()
720734
f.write(lines)

tests/test_io_obj.py

+33
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,39 @@ def test_save_obj(self):
480480
actual_file = obj_file.getvalue()
481481
self.assertEqual(actual_file, expected_file)
482482

483+
def test_save_obj_with_normals(self):
484+
verts = torch.tensor(
485+
[[0.01, 0.2, 0.301], [0.2, 0.03, 0.408], [0.3, 0.4, 0.05], [0.6, 0.7, 0.8]],
486+
dtype=torch.float32,
487+
)
488+
faces = torch.tensor(
489+
[[0, 2, 1], [0, 1, 2], [3, 2, 1], [3, 1, 0]], dtype=torch.int64
490+
)
491+
normals = torch.tensor(
492+
[[0.01, 0.2, 0.301], [0.2, 0.03, 0.408], [0.3, 0.4, 0.05], [0.6, 0.7, 0.8]],
493+
dtype=torch.float32,
494+
)
495+
obj_file = StringIO()
496+
save_obj(obj_file, verts, faces, normals, decimal_places=2)
497+
expected_file = "\n".join(
498+
[
499+
"v 0.01 0.20 0.30",
500+
"v 0.20 0.03 0.41",
501+
"v 0.30 0.40 0.05",
502+
"v 0.60 0.70 0.80",
503+
"vn 0.01 0.20 0.30",
504+
"vn 0.20 0.03 0.41",
505+
"vn 0.30 0.40 0.05",
506+
"vn 0.60 0.70 0.80",
507+
"f 1//1 3//3 2//2",
508+
"f 1//1 2//2 3//3",
509+
"f 4//4 3//3 2//2",
510+
"f 4//4 2//2 1//1",
511+
]
512+
)
513+
actual_file = obj_file.getvalue()
514+
self.assertEqual(actual_file, expected_file)
515+
483516
def test_load_mtl(self):
484517
obj_filename = "cow_mesh/cow.obj"
485518
filename = os.path.join(TUTORIAL_DATA_DIR, obj_filename)

0 commit comments

Comments
 (0)