Skip to content

Commit b0462d8

Browse files
Allow indexing for classes inheriting Transform3d (#1801)
Summary: Currently, it is not possible to access a sub-transform using an indexer for all 3d transforms inheriting the `Transforms3d` class. For instance: ```python from pytorch3d import transforms N = 10 r = transforms.random_rotations(N) T = transforms.Transform3d().rotate(R=r) R = transforms.Rotate(r) x = T[0] # ok x = R[0] # TypeError: __init__() got an unexpected keyword argument 'matrix' ``` This is because all these classes (namely `Rotate`, `Translate`, `Scale`, `RotateAxisAngle`) inherit the `__getitem__()` method from `Transform3d` which has the [following code on line 201](https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/transform3d.py#L201): ```python return self.__class__(matrix=self.get_matrix()[index]) ``` The four classes inheriting `Transform3d` are not initialized through a matrix argument, hence they error. I propose to modify the `__getitem__()` method of the `Transform3d` class to fix this behavior. The least invasive way to do it I can think of consists of creating an empty instance of the current class, then setting the `_matrix` attribute manually. Thus, instead of ```python return self.__class__(matrix=self.get_matrix()[index]) ``` I propose to do: ```python instance = self.__class__.__new__(self.__class__) instance._matrix = self.get_matrix()[index] return instance ``` As far as I can tell, this modification occurs no modification whatsoever for the user, except for the ability to index all 3d transforms. Pull Request resolved: #1801 Reviewed By: MichaelRamamonjisoa Differential Revision: D58410389 Pulled By: bottler fbshipit-source-id: f371e4c63d2ae4c927a7ad48c2de8862761078de
1 parent b66d17a commit b0462d8

File tree

2 files changed

+79
-0
lines changed

2 files changed

+79
-0
lines changed

pytorch3d/transforms/transform3d.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -564,6 +564,22 @@ def _get_matrix_inverse(self) -> torch.Tensor:
564564
i_matrix = self._matrix * inv_mask
565565
return i_matrix
566566

567+
def __getitem__(
568+
self, index: Union[int, List[int], slice, torch.BoolTensor, torch.LongTensor]
569+
) -> "Transform3d":
570+
"""
571+
Args:
572+
index: Specifying the index of the transform to retrieve.
573+
Can be an int, slice, list of ints, boolean, long tensor.
574+
Supports negative indices.
575+
576+
Returns:
577+
Transform3d object with selected transforms. The tensors are not cloned.
578+
"""
579+
if isinstance(index, int):
580+
index = [index]
581+
return self.__class__(self.get_matrix()[index, 3, :3])
582+
567583

568584
class Scale(Transform3d):
569585
def __init__(
@@ -613,6 +629,26 @@ def _get_matrix_inverse(self) -> torch.Tensor:
613629
imat = torch.diag_embed(ixyz, dim1=1, dim2=2)
614630
return imat
615631

632+
def __getitem__(
633+
self, index: Union[int, List[int], slice, torch.BoolTensor, torch.LongTensor]
634+
) -> "Transform3d":
635+
"""
636+
Args:
637+
index: Specifying the index of the transform to retrieve.
638+
Can be an int, slice, list of ints, boolean, long tensor.
639+
Supports negative indices.
640+
641+
Returns:
642+
Transform3d object with selected transforms. The tensors are not cloned.
643+
"""
644+
if isinstance(index, int):
645+
index = [index]
646+
mat = self.get_matrix()[index]
647+
x = mat[:, 0, 0]
648+
y = mat[:, 1, 1]
649+
z = mat[:, 2, 2]
650+
return self.__class__(x, y, z)
651+
616652

617653
class Rotate(Transform3d):
618654
def __init__(
@@ -655,6 +691,22 @@ def _get_matrix_inverse(self) -> torch.Tensor:
655691
"""
656692
return self._matrix.permute(0, 2, 1).contiguous()
657693

694+
def __getitem__(
695+
self, index: Union[int, List[int], slice, torch.BoolTensor, torch.LongTensor]
696+
) -> "Transform3d":
697+
"""
698+
Args:
699+
index: Specifying the index of the transform to retrieve.
700+
Can be an int, slice, list of ints, boolean, long tensor.
701+
Supports negative indices.
702+
703+
Returns:
704+
Transform3d object with selected transforms. The tensors are not cloned.
705+
"""
706+
if isinstance(index, int):
707+
index = [index]
708+
return self.__class__(self.get_matrix()[index, :3, :3])
709+
658710

659711
class RotateAxisAngle(Rotate):
660712
def __init__(

tests/test_transforms.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -685,6 +685,15 @@ def test_inverse(self):
685685
self.assertTrue(torch.allclose(im, im_comp))
686686
self.assertTrue(torch.allclose(im, im_2))
687687

688+
def test_get_item(self, batch_size=5):
689+
device = torch.device("cuda:0")
690+
xyz = torch.randn(size=[batch_size, 3], device=device, dtype=torch.float32)
691+
t3d = Translate(xyz)
692+
index = 1
693+
t3d_selected = t3d[index]
694+
self.assertEqual(len(t3d_selected), 1)
695+
self.assertIsInstance(t3d_selected, Translate)
696+
688697

689698
class TestScale(unittest.TestCase):
690699
def test_single_python_scalar(self):
@@ -871,6 +880,15 @@ def test_inverse(self):
871880
self.assertTrue(torch.allclose(im, im_comp))
872881
self.assertTrue(torch.allclose(im, im_2))
873882

883+
def test_get_item(self, batch_size=5):
884+
device = torch.device("cuda:0")
885+
s = torch.randn(size=[batch_size, 3], device=device, dtype=torch.float32)
886+
t3d = Scale(s)
887+
index = 1
888+
t3d_selected = t3d[index]
889+
self.assertEqual(len(t3d_selected), 1)
890+
self.assertIsInstance(t3d_selected, Scale)
891+
874892

875893
class TestTransformBroadcast(unittest.TestCase):
876894
def test_broadcast_transform_points(self):
@@ -986,6 +1004,15 @@ def test_inverse(self, batch_size=5):
9861004
self.assertTrue(torch.allclose(im, im_comp, atol=1e-4))
9871005
self.assertTrue(torch.allclose(im, im_2, atol=1e-4))
9881006

1007+
def test_get_item(self, batch_size=5):
1008+
device = torch.device("cuda:0")
1009+
r = random_rotations(batch_size, dtype=torch.float32, device=device)
1010+
t3d = Rotate(r)
1011+
index = 1
1012+
t3d_selected = t3d[index]
1013+
self.assertEqual(len(t3d_selected), 1)
1014+
self.assertIsInstance(t3d_selected, Rotate)
1015+
9891016

9901017
class TestRotateAxisAngle(unittest.TestCase):
9911018
def test_rotate_x_python_scalar(self):

0 commit comments

Comments
 (0)