Skip to content

Commit b7c826b

Browse files
shapovalovfacebook-github-bot
authored andcommitted
Boolean indexing of cameras
Summary: Reasonable to expect bool indexing. Reviewed By: bottler, kjchalup Differential Revision: D38741446 fbshipit-source-id: 22b607bf13110043c5624196c66ca1484fdbce6c
1 parent 6080897 commit b7c826b

File tree

6 files changed

+58
-18
lines changed

6 files changed

+58
-18
lines changed

pytorch3d/renderer/cameras.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -385,31 +385,45 @@ def get_image_size(self):
385385
return self.image_size if hasattr(self, "image_size") else None
386386

387387
def __getitem__(
388-
self, index: Union[int, List[int], torch.LongTensor]
388+
self, index: Union[int, List[int], torch.BoolTensor, torch.LongTensor]
389389
) -> "CamerasBase":
390390
"""
391391
Override for the __getitem__ method in TensorProperties which needs to be
392392
refactored.
393393
394394
Args:
395-
index: an int/list/long tensor used to index all the fields in the cameras given by
396-
self._FIELDS.
395+
index: an integer index, list/tensor of integer indices, or tensor of boolean
396+
indicators used to filter all the fields in the cameras given by self._FIELDS.
397397
Returns:
398-
if `index` is an index int/list/long tensor return an instance of the current
399-
cameras class with only the values at the selected index.
398+
an instance of the current cameras class with only the values at the selected index.
400399
"""
401400

402401
kwargs = {}
403402

404-
# pyre-fixme[16]: Module `cuda` has no attribute `LongTensor`.
405-
if not isinstance(index, (int, list, torch.LongTensor, torch.cuda.LongTensor)):
406-
msg = "Invalid index type, expected int, List[int] or torch.LongTensor; got %r"
403+
tensor_types = {
404+
"bool": (torch.BoolTensor, torch.cuda.BoolTensor),
405+
"long": (torch.LongTensor, torch.cuda.LongTensor),
406+
}
407+
if not isinstance(
408+
index, (int, list, *tensor_types["bool"], *tensor_types["long"])
409+
) or (
410+
isinstance(index, list)
411+
and not all(isinstance(i, int) and not isinstance(i, bool) for i in index)
412+
):
413+
msg = (
414+
"Invalid index type, expected int, List[int] or Bool/LongTensor; got %r"
415+
)
407416
raise ValueError(msg % type(index))
408417

409418
if isinstance(index, int):
410419
index = [index]
411420

412-
if max(index) >= len(self):
421+
if isinstance(index, tensor_types["bool"]):
422+
if index.ndim != 1 or index.shape[0] != len(self):
423+
raise ValueError(
424+
f"Boolean index of shape {index.shape} does not match cameras"
425+
)
426+
elif max(index) >= len(self):
413427
raise ValueError(f"Index {max(index)} is out of bounds for select cameras")
414428

415429
for field in self._FIELDS:

pytorch3d/structures/meshes.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,9 @@ def _set_verts_normals(self, verts_normals) -> None:
472472
def __len__(self) -> int:
473473
return self._N
474474

475-
def __getitem__(self, index) -> "Meshes":
475+
def __getitem__(
476+
self, index: Union[int, List[int], slice, torch.BoolTensor, torch.LongTensor]
477+
) -> "Meshes":
476478
"""
477479
Args:
478480
index: Specifying the index of the mesh to retrieve.

pytorch3d/structures/pointclouds.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,10 @@ def _parse_auxiliary_input_list(
360360
def __len__(self) -> int:
361361
return self._N
362362

363-
def __getitem__(self, index) -> "Pointclouds":
363+
def __getitem__(
364+
self,
365+
index: Union[int, List[int], slice, torch.BoolTensor, torch.LongTensor],
366+
) -> "Pointclouds":
364367
"""
365368
Args:
366369
index: Specifying the index of the cloud to retrieve.

pytorch3d/structures/volumes.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -501,7 +501,10 @@ def __len__(self) -> int:
501501
return self._densities.shape[0]
502502

503503
def __getitem__(
504-
self, index: Union[int, List[int], Tuple[int], slice, torch.Tensor]
504+
self,
505+
index: Union[
506+
int, List[int], Tuple[int], slice, torch.BoolTensor, torch.LongTensor
507+
],
505508
) -> "Volumes":
506509
"""
507510
Args:

pytorch3d/transforms/transform3d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def __len__(self) -> int:
181181
return self.get_matrix().shape[0]
182182

183183
def __getitem__(
184-
self, index: Union[int, List[int], slice, torch.Tensor]
184+
self, index: Union[int, List[int], slice, torch.BoolTensor, torch.LongTensor]
185185
) -> "Transform3d":
186186
"""
187187
Args:

tests/test_cameras.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -884,7 +884,8 @@ def test_camera_class_init(self):
884884
self.assertTrue(new_cam.device == device)
885885

886886
def test_getitem(self):
887-
R_matrix = torch.randn((6, 3, 3))
887+
N_CAMERAS = 6
888+
R_matrix = torch.randn((N_CAMERAS, 3, 3))
888889
cam = FoVPerspectiveCameras(znear=10.0, zfar=100.0, R=R_matrix)
889890

890891
# Check get item returns an instance of the same class
@@ -908,22 +909,39 @@ def test_getitem(self):
908909
self.assertClose(c012.R, R_matrix[0:3, ...])
909910

910911
# Check torch.LongTensor index
911-
index = torch.tensor([1, 3, 5], dtype=torch.int64)
912+
SLICE = [1, 3, 5]
913+
index = torch.tensor(SLICE, dtype=torch.int64)
912914
c135 = cam[index]
913915
self.assertEqual(len(c135), 3)
914916
self.assertClose(c135.zfar, torch.tensor([100.0] * 3))
915917
self.assertClose(c135.znear, torch.tensor([10.0] * 3))
916-
self.assertClose(c135.R, R_matrix[[1, 3, 5], ...])
918+
self.assertClose(c135.R, R_matrix[SLICE, ...])
919+
920+
# Check torch.BoolTensor index
921+
bool_slice = [i in SLICE for i in range(N_CAMERAS)]
922+
index = torch.tensor(bool_slice, dtype=torch.bool)
923+
c135 = cam[index]
924+
self.assertEqual(len(c135), 3)
925+
self.assertClose(c135.zfar, torch.tensor([100.0] * 3))
926+
self.assertClose(c135.znear, torch.tensor([10.0] * 3))
927+
self.assertClose(c135.R, R_matrix[SLICE, ...])
917928

918929
# Check errors with get item
919930
with self.assertRaisesRegex(ValueError, "out of bounds"):
920-
cam[6]
931+
cam[N_CAMERAS]
932+
933+
with self.assertRaisesRegex(ValueError, "does not match cameras"):
934+
index = torch.tensor([1, 0, 1], dtype=torch.bool)
935+
cam[index]
921936

922937
with self.assertRaisesRegex(ValueError, "Invalid index type"):
923938
cam[slice(0, 1)]
924939

925940
with self.assertRaisesRegex(ValueError, "Invalid index type"):
926-
index = torch.tensor([1, 3, 5], dtype=torch.float32)
941+
cam[[True, False]]
942+
943+
with self.assertRaisesRegex(ValueError, "Invalid index type"):
944+
index = torch.tensor(SLICE, dtype=torch.float32)
927945
cam[index]
928946

929947
def test_get_full_transform(self):

0 commit comments

Comments
 (0)