Skip to content

Commit cd5af25

Browse files
theschnitzfacebook-github-bot
authored andcommitted
Update Rotate transform to use device of input rotation
Summary: Currently the Rotate transform does not consider the R's device at all, resulting in errors if you're expecting it to be on cuda but it gets the default casting to cpu. This updates the transform to respect R's device. Reviewed By: nikhilaravi Differential Revision: D27828118 fbshipit-source-id: ddd99f73eadbd990688eb22f3d1ffbacbe168c81
1 parent c9dea62 commit cd5af25

File tree

1 file changed

+28
-10
lines changed

1 file changed

+28
-10
lines changed

pytorch3d/transforms/transform3d.py

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -434,7 +434,7 @@ def cuda(self):
434434

435435

436436
class Translate(Transform3d):
437-
def __init__(self, x, y=None, z=None, dtype=torch.float32, device="cpu"):
437+
def __init__(self, x, y=None, z=None, dtype=torch.float32, device=None):
438438
"""
439439
Create a new Transform3d representing 3D translations.
440440
@@ -448,11 +448,11 @@ def __init__(self, x, y=None, z=None, dtype=torch.float32, device="cpu"):
448448
- A torch scalar
449449
- A 1D torch tensor
450450
"""
451-
super().__init__(device=device)
452451
xyz = _handle_input(x, y, z, dtype, device, "Translate")
452+
super().__init__(device=xyz.device)
453453
N = xyz.shape[0]
454454

455-
mat = torch.eye(4, dtype=dtype, device=device)
455+
mat = torch.eye(4, dtype=dtype, device=self.device)
456456
mat = mat.view(1, 4, 4).repeat(N, 1, 1)
457457
mat[:, 3, :3] = xyz
458458
self._matrix = mat
@@ -468,7 +468,7 @@ def _get_matrix_inverse(self):
468468

469469

470470
class Scale(Transform3d):
471-
def __init__(self, x, y=None, z=None, dtype=torch.float32, device="cpu"):
471+
def __init__(self, x, y=None, z=None, dtype=torch.float32, device=None):
472472
"""
473473
A Transform3d representing a scaling operation, with different scale
474474
factors along each coordinate axis.
@@ -485,12 +485,12 @@ def __init__(self, x, y=None, z=None, dtype=torch.float32, device="cpu"):
485485
- torch scalar
486486
- 1D torch tensor
487487
"""
488-
super().__init__(device=device)
489488
xyz = _handle_input(x, y, z, dtype, device, "scale", allow_singleton=True)
489+
super().__init__(device=xyz.device)
490490
N = xyz.shape[0]
491491

492492
# TODO: Can we do this all in one go somehow?
493-
mat = torch.eye(4, dtype=dtype, device=device)
493+
mat = torch.eye(4, dtype=dtype, device=self.device)
494494
mat = mat.view(1, 4, 4).repeat(N, 1, 1)
495495
mat[:, 0, 0] = xyz[:, 0]
496496
mat[:, 1, 1] = xyz[:, 1]
@@ -509,7 +509,7 @@ def _get_matrix_inverse(self):
509509

510510
class Rotate(Transform3d):
511511
def __init__(
512-
self, R, dtype=torch.float32, device="cpu", orthogonal_tol: float = 1e-5
512+
self, R, dtype=torch.float32, device=None, orthogonal_tol: float = 1e-5
513513
):
514514
"""
515515
Create a new Transform3d representing 3D rotation using a rotation
@@ -520,6 +520,7 @@ def __init__(
520520
orthogonal_tol: tolerance for the test of the orthogonality of R
521521
522522
"""
523+
device = _get_device(R, device)
523524
super().__init__(device=device)
524525
if R.dim() == 2:
525526
R = R[None]
@@ -548,7 +549,7 @@ def __init__(
548549
axis: str = "X",
549550
degrees: bool = True,
550551
dtype=torch.float64,
551-
device="cpu",
552+
device=None,
552553
):
553554
"""
554555
Create a new Transform3d representing 3D rotation about an axis
@@ -578,7 +579,7 @@ def __init__(
578579
# is for transforming column vectors. Therefore we transpose this matrix.
579580
# R will always be of shape (N, 3, 3)
580581
R = _axis_angle_rotation(axis, angle).transpose(1, 2)
581-
super().__init__(device=device, R=R)
582+
super().__init__(device=angle.device, R=R)
582583

583584

584585
def _handle_coord(c, dtype, device):
@@ -595,9 +596,24 @@ def _handle_coord(c, dtype, device):
595596
c = torch.tensor(c, dtype=dtype, device=device)
596597
if c.dim() == 0:
597598
c = c.view(1)
599+
if c.device != device:
600+
c = c.to(device=device)
598601
return c
599602

600603

604+
def _get_device(x, device=None):
605+
if device is not None:
606+
# User overriding device, leave
607+
device = device
608+
elif torch.is_tensor(x):
609+
# Set device based on input tensor
610+
device = x.device
611+
else:
612+
# Default device is cpu
613+
device = "cpu"
614+
return device
615+
616+
601617
def _handle_input(x, y, z, dtype, device, name: str, allow_singleton: bool = False):
602618
"""
603619
Helper function to handle parsing logic for building transforms. The output
@@ -626,6 +642,7 @@ def _handle_input(x, y, z, dtype, device, name: str, allow_singleton: bool = Fal
626642
Returns:
627643
xyz: Tensor of shape (N, 3)
628644
"""
645+
device = _get_device(x, device)
629646
# If x is actually a tensor of shape (N, 3) then just return it
630647
if torch.is_tensor(x) and x.dim() == 2:
631648
if x.shape[1] != 3:
@@ -634,7 +651,7 @@ def _handle_input(x, y, z, dtype, device, name: str, allow_singleton: bool = Fal
634651
if y is not None or z is not None:
635652
msg = "Expected y and z to be None (in %s)" % name
636653
raise ValueError(msg)
637-
return x
654+
return x.to(device=device)
638655

639656
if allow_singleton and y is None and z is None:
640657
y = x
@@ -665,6 +682,7 @@ def _handle_angle_input(x, dtype, device, name: str):
665682
- Python scalar
666683
- Torch scalar
667684
"""
685+
device = _get_device(x, device)
668686
if torch.is_tensor(x) and x.dim() > 1:
669687
msg = "Expected tensor of shape (N,); got %r (in %s)"
670688
raise ValueError(msg % (x.shape, name))

0 commit comments

Comments
 (0)