You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Allow indexing for classes inheriting Transform3d (#1801)
Summary:
Currently, it is not possible to access a sub-transform using an indexer for all 3d transforms inheriting the `Transforms3d` class.
For instance:
```python
from pytorch3d import transforms
N = 10
r = transforms.random_rotations(N)
T = transforms.Transform3d().rotate(R=r)
R = transforms.Rotate(r)
x = T[0] # ok
x = R[0] # TypeError: __init__() got an unexpected keyword argument 'matrix'
```
This is because all these classes (namely `Rotate`, `Translate`, `Scale`, `RotateAxisAngle`) inherit the `__getitem__()` method from `Transform3d` which has the [following code on line 201](https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/transform3d.py#L201):
```python
return self.__class__(matrix=self.get_matrix()[index])
```
The four classes inheriting `Transform3d` are not initialized through a matrix argument, hence they error.
I propose to modify the `__getitem__()` method of the `Transform3d` class to fix this behavior. The least invasive way to do it I can think of consists of creating an empty instance of the current class, then setting the `_matrix` attribute manually. Thus, instead of
```python
return self.__class__(matrix=self.get_matrix()[index])
```
I propose to do:
```python
instance = self.__class__.__new__(self.__class__)
instance._matrix = self.get_matrix()[index]
return instance
```
As far as I can tell, this modification occurs no modification whatsoever for the user, except for the ability to index all 3d transforms.
Pull Request resolved: #1801
Reviewed By: MichaelRamamonjisoa
Differential Revision: D58410389
Pulled By: bottler
fbshipit-source-id: f371e4c63d2ae4c927a7ad48c2de8862761078de
0 commit comments