diff --git a/pytorch3d/io/obj_io.py b/pytorch3d/io/obj_io.py index bb329fb4f..f285d7b33 100644 --- a/pytorch3d/io/obj_io.py +++ b/pytorch3d/io/obj_io.py @@ -648,11 +648,12 @@ def _load_obj( def save_obj( - f, - verts, - faces, - decimal_places: Optional[int] = None, - path_manager: Optional[PathManager] = None, + f, + verts, + faces, + normals: Optional[torch.FloatTensor] = None, + decimal_places: Optional[int] = None, + path_manager: Optional[PathManager] = None, ): """ Save a mesh to an .obj file. @@ -661,6 +662,7 @@ def save_obj( f: File (or path) to which the mesh should be written. verts: FloatTensor of shape (V, 3) giving vertex coordinates. faces: LongTensor of shape (F, 3) giving faces. + normals: FloatTensor of shape (N, 3) giving normals. decimal_places: Number of decimal places for saving. path_manager: Optional PathManager for interpreting f if it is a str. @@ -677,44 +679,56 @@ def save_obj( path_manager = PathManager() with _open_file(f, path_manager, "w") as f: - return _save(f, verts, faces, decimal_places) + return _save(f, verts, faces, normals, decimal_places) -# TODO (nikhilar) Speed up this function. -def _save(f, verts, faces, decimal_places: Optional[int] = None) -> None: +def _save( + f, + verts, + faces, + normals: Optional[torch.FloatTensor] = None, + decimal_places: Optional[int] = None, +) -> None: assert not len(verts) or (verts.dim() == 2 and verts.size(1) == 3) assert not len(faces) or (faces.dim() == 2 and faces.size(1) == 3) + if normals is not None: + assert not len(normals) or (normals.dim() == 2 and normals.size(1) == 3) + assert len(normals) == len(faces) or len(normals) == len(verts) if not (len(verts) or len(faces)): warnings.warn("Empty 'verts' and 'faces' arguments provided") return - verts, faces = verts.cpu(), faces.cpu() + if torch.any(faces >= verts.shape[0]) or torch.any(faces < 0): + warnings.warn("Faces have invalid indices") + + if decimal_places is None: + float_format = "{:f}".format + else: + float_format = ("{:.%df}" % decimal_places).format + + float_format = np.vectorize(float_format) + verts, faces = verts.cpu().numpy(), faces.cpu().numpy() lines = "" if len(verts): - if decimal_places is None: - float_str = "%f" - else: - float_str = "%" + ".%df" % decimal_places + rows = np.apply_along_axis(" ".join, 1, float_format(verts)) + lines += "v " + "\nv ".join(rows) - V, D = verts.shape - for i in range(V): - vert = [float_str % verts[i, j] for j in range(D)] - lines += "v %s\n" % " ".join(vert) - - if torch.any(faces >= verts.shape[0]) or torch.any(faces < 0): - warnings.warn("Faces have invalid indices") + if normals is not None and len(normals): + normals = normals.cpu().numpy() + rows = np.apply_along_axis(" ".join, 1, float_format(normals)) + lines += "\nvn " + "\nvn ".join(rows) if len(faces): - F, P = faces.shape - for i in range(F): - face = ["%d" % (faces[i, j] + 1) for j in range(P)] - if i + 1 < F: - lines += "f %s\n" % " ".join(face) - elif i + 1 == F: - # No newline at the end of the file. - lines += "f %s" % " ".join(face) - + # faces start indexing with 1 and not 0 + faces = (faces + 1).astype(int).astype(str) + if normals is not None: + normals_format = np.vectorize("{0}//{0}".format) + faces = normals_format(faces) + rows = np.apply_along_axis(" ".join, 1, faces) + lines += "\nf " + "\nf ".join(rows) + + lines = lines.strip() f.write(lines) diff --git a/tests/test_io_obj.py b/tests/test_io_obj.py index 34b488b2e..13a795f32 100644 --- a/tests/test_io_obj.py +++ b/tests/test_io_obj.py @@ -480,6 +480,39 @@ def test_save_obj(self): actual_file = obj_file.getvalue() self.assertEqual(actual_file, expected_file) + def test_save_obj_with_normals(self): + verts = torch.tensor( + [[0.01, 0.2, 0.301], [0.2, 0.03, 0.408], [0.3, 0.4, 0.05], [0.6, 0.7, 0.8]], + dtype=torch.float32, + ) + faces = torch.tensor( + [[0, 2, 1], [0, 1, 2], [3, 2, 1], [3, 1, 0]], dtype=torch.int64 + ) + normals = torch.tensor( + [[0.01, 0.2, 0.301], [0.2, 0.03, 0.408], [0.3, 0.4, 0.05], [0.6, 0.7, 0.8]], + dtype=torch.float32, + ) + obj_file = StringIO() + save_obj(obj_file, verts, faces, normals, decimal_places=2) + expected_file = "\n".join( + [ + "v 0.01 0.20 0.30", + "v 0.20 0.03 0.41", + "v 0.30 0.40 0.05", + "v 0.60 0.70 0.80", + "vn 0.01 0.20 0.30", + "vn 0.20 0.03 0.41", + "vn 0.30 0.40 0.05", + "vn 0.60 0.70 0.80", + "f 1//1 3//3 2//2", + "f 1//1 2//2 3//3", + "f 4//4 3//3 2//2", + "f 4//4 2//2 1//1", + ] + ) + actual_file = obj_file.getvalue() + self.assertEqual(actual_file, expected_file) + def test_load_mtl(self): obj_filename = "cow_mesh/cow.obj" filename = os.path.join(TUTORIAL_DATA_DIR, obj_filename)