Skip to content

Commit 745aaf3

Browse files
patricklabatutfacebook-github-bot
authored andcommitted
No side effect with invalid inputs to save_obj / save_ply
Summary: Do not create output files with invalid inputs to `save_{obj,ply}()`. Reviewed By: bottler Differential Revision: D20720282 fbshipit-source-id: 3b611a10da6f6eecacab2a1900bf16f89e2000aa
1 parent 83feed5 commit 745aaf3

File tree

2 files changed

+23
-21
lines changed

2 files changed

+23
-21
lines changed

pytorch3d/io/obj_io.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -526,6 +526,14 @@ def save_obj(f, verts, faces, decimal_places: Optional[int] = None):
526526
faces: LongTensor of shape (F, 3) giving faces.
527527
decimal_places: Number of decimal places for saving.
528528
"""
529+
if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3):
530+
message = "Argument 'verts' should either be empty or of shape (num_verts, 3)."
531+
raise ValueError(message)
532+
533+
if len(faces) and not (faces.dim() == 2 and faces.size(1) == 3):
534+
message = "Argument 'faces' should either be empty or of shape (num_faces, 3)."
535+
raise ValueError(message)
536+
529537
new_f = False
530538
if isinstance(f, str):
531539
new_f = True
@@ -541,21 +549,14 @@ def save_obj(f, verts, faces, decimal_places: Optional[int] = None):
541549

542550

543551
# TODO (nikhilar) Speed up this function.
544-
def _save(f, verts, faces, decimal_places: Optional[int] = None):
552+
def _save(f, verts, faces, decimal_places: Optional[int] = None) -> None:
553+
assert not len(verts) or (verts.dim() == 2 and verts.size(1) == 3)
554+
assert not len(faces) or (faces.dim() == 2 and faces.size(1) == 3)
555+
545556
if not (len(verts) or len(faces)):
546557
warnings.warn("Empty 'verts' and 'faces' arguments provided")
547558
return
548559

549-
if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3):
550-
raise ValueError(
551-
"Argument 'verts' should either be empty or of shape (num_verts, 3)."
552-
)
553-
554-
if len(faces) and not (faces.dim() == 2 and faces.size(1) == 3):
555-
raise ValueError(
556-
"Argument 'faces' should either be empty or of shape (num_faces, 3)."
557-
)
558-
559560
verts, faces = verts.cpu(), faces.cpu()
560561

561562
lines = ""

pytorch3d/io/ply_io.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -700,7 +700,7 @@ def load_ply(f):
700700
return verts, faces
701701

702702

703-
def _save_ply(f, verts, faces, decimal_places: Optional[int]):
703+
def _save_ply(f, verts, faces, decimal_places: Optional[int]) -> None:
704704
"""
705705
Internal implementation for saving a mesh to a .ply file.
706706
@@ -710,15 +710,8 @@ def _save_ply(f, verts, faces, decimal_places: Optional[int]):
710710
faces: LongTensor of shape (F, 3) giving faces.
711711
decimal_places: Number of decimal places for saving.
712712
"""
713-
if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3):
714-
raise ValueError(
715-
"Argument 'verts' should either be empty or of shape (num_verts, 3)."
716-
)
717-
718-
if len(faces) and not (faces.dim() == 2 and faces.size(1) == 3):
719-
raise ValueError(
720-
"Argument 'faces' should either be empty or of shape (num_faces, 3)."
721-
)
713+
assert not len(verts) or (verts.dim() == 2 and verts.size(1) == 3)
714+
assert not len(faces) or (faces.dim() == 2 and faces.size(1) == 3)
722715

723716
print("ply\nformat ascii 1.0", file=f)
724717
print(f"element vertex {verts.shape[0]}", file=f)
@@ -760,6 +753,14 @@ def save_ply(f, verts, faces, decimal_places: Optional[int] = None):
760753
faces: LongTensor of shape (F, 3) giving faces.
761754
decimal_places: Number of decimal places for saving.
762755
"""
756+
if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3):
757+
message = "Argument 'verts' should either be empty or of shape (num_verts, 3)."
758+
raise ValueError(message)
759+
760+
if len(faces) and not (faces.dim() == 2 and faces.size(1) == 3):
761+
message = "Argument 'faces' should either be empty or of shape (num_faces, 3)."
762+
raise ValueError(message)
763+
763764
new_f = False
764765
if isinstance(f, str):
765766
new_f = True

0 commit comments

Comments
 (0)