Skip to content

Speed up save_obj function, add option of saving normals. #667

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 43 additions & 29 deletions pytorch3d/io/obj_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,11 +648,12 @@ def _load_obj(


def save_obj(
f,
verts,
faces,
decimal_places: Optional[int] = None,
path_manager: Optional[PathManager] = None,
f,
verts,
faces,
normals: Optional[torch.FloatTensor] = None,
decimal_places: Optional[int] = None,
path_manager: Optional[PathManager] = None,
):
"""
Save a mesh to an .obj file.
Expand All @@ -661,6 +662,7 @@ def save_obj(
f: File (or path) to which the mesh should be written.
verts: FloatTensor of shape (V, 3) giving vertex coordinates.
faces: LongTensor of shape (F, 3) giving faces.
normals: FloatTensor of shape (N, 3) giving normals.
decimal_places: Number of decimal places for saving.
path_manager: Optional PathManager for interpreting f if
it is a str.
Expand All @@ -677,44 +679,56 @@ def save_obj(
path_manager = PathManager()

with _open_file(f, path_manager, "w") as f:
return _save(f, verts, faces, decimal_places)
return _save(f, verts, faces, normals, decimal_places)


# TODO (nikhilar) Speed up this function.
def _save(f, verts, faces, decimal_places: Optional[int] = None) -> None:
def _save(
f,
verts,
faces,
normals: Optional[torch.FloatTensor] = None,
decimal_places: Optional[int] = None,
) -> None:
assert not len(verts) or (verts.dim() == 2 and verts.size(1) == 3)
assert not len(faces) or (faces.dim() == 2 and faces.size(1) == 3)
if normals is not None:
assert not len(normals) or (normals.dim() == 2 and normals.size(1) == 3)
assert len(normals) == len(faces) or len(normals) == len(verts)

if not (len(verts) or len(faces)):
warnings.warn("Empty 'verts' and 'faces' arguments provided")
return

verts, faces = verts.cpu(), faces.cpu()
if torch.any(faces >= verts.shape[0]) or torch.any(faces < 0):
warnings.warn("Faces have invalid indices")

if decimal_places is None:
float_format = "{:f}".format
else:
float_format = ("{:.%df}" % decimal_places).format

float_format = np.vectorize(float_format)
verts, faces = verts.cpu().numpy(), faces.cpu().numpy()

lines = ""

if len(verts):
if decimal_places is None:
float_str = "%f"
else:
float_str = "%" + ".%df" % decimal_places
rows = np.apply_along_axis(" ".join, 1, float_format(verts))
lines += "v " + "\nv ".join(rows)

V, D = verts.shape
for i in range(V):
vert = [float_str % verts[i, j] for j in range(D)]
lines += "v %s\n" % " ".join(vert)

if torch.any(faces >= verts.shape[0]) or torch.any(faces < 0):
warnings.warn("Faces have invalid indices")
if normals is not None and len(normals):
normals = normals.cpu().numpy()
rows = np.apply_along_axis(" ".join, 1, float_format(normals))
lines += "\nvn " + "\nvn ".join(rows)

if len(faces):
F, P = faces.shape
for i in range(F):
face = ["%d" % (faces[i, j] + 1) for j in range(P)]
if i + 1 < F:
lines += "f %s\n" % " ".join(face)
elif i + 1 == F:
# No newline at the end of the file.
lines += "f %s" % " ".join(face)

# faces start indexing with 1 and not 0
faces = (faces + 1).astype(int).astype(str)
if normals is not None:
normals_format = np.vectorize("{0}//{0}".format)
faces = normals_format(faces)
rows = np.apply_along_axis(" ".join, 1, faces)
lines += "\nf " + "\nf ".join(rows)

lines = lines.strip()
f.write(lines)
33 changes: 33 additions & 0 deletions tests/test_io_obj.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,39 @@ def test_save_obj(self):
actual_file = obj_file.getvalue()
self.assertEqual(actual_file, expected_file)

def test_save_obj_with_normals(self):
verts = torch.tensor(
[[0.01, 0.2, 0.301], [0.2, 0.03, 0.408], [0.3, 0.4, 0.05], [0.6, 0.7, 0.8]],
dtype=torch.float32,
)
faces = torch.tensor(
[[0, 2, 1], [0, 1, 2], [3, 2, 1], [3, 1, 0]], dtype=torch.int64
)
normals = torch.tensor(
[[0.01, 0.2, 0.301], [0.2, 0.03, 0.408], [0.3, 0.4, 0.05], [0.6, 0.7, 0.8]],
dtype=torch.float32,
)
obj_file = StringIO()
save_obj(obj_file, verts, faces, normals, decimal_places=2)
expected_file = "\n".join(
[
"v 0.01 0.20 0.30",
"v 0.20 0.03 0.41",
"v 0.30 0.40 0.05",
"v 0.60 0.70 0.80",
"vn 0.01 0.20 0.30",
"vn 0.20 0.03 0.41",
"vn 0.30 0.40 0.05",
"vn 0.60 0.70 0.80",
"f 1//1 3//3 2//2",
"f 1//1 2//2 3//3",
"f 4//4 3//3 2//2",
"f 4//4 2//2 1//1",
]
)
actual_file = obj_file.getvalue()
self.assertEqual(actual_file, expected_file)

def test_load_mtl(self):
obj_filename = "cow_mesh/cow.obj"
filename = os.path.join(TUTORIAL_DATA_DIR, obj_filename)
Expand Down