Skip to content

Commit b602edc

Browse files
janEbertfacebook-github-bot
authored andcommitted
Fix dtype propagation (#1141)
Summary: Previously, dtypes were not propagated correctly in composed transforms, resulting in errors when different dtypes were mixed. Even specifying a dtype in the constructor does not fix this. Neither does specifying the dtype for each composition function invocation (e.g. as a `kwarg` in `rotate_axis_angle`). With the change, I also had to modify the default dtype of `RotateAxisAngle`, which was `torch.float64`; it is now `torch.float32` like for all other transforms. This was required because the fix in propagation broke some tests due to dtype mismatches. This change in default dtype in turn broke two tests due to precision changes (calculations that were previously done in `torch.float64` were now done in `torch.float32`), so I changed the precision tolerances to be less strict. I chose the lowest power of ten that passed the tests here. Pull Request resolved: #1141 Reviewed By: patricklabatut Differential Revision: D35192970 Pulled By: bottler fbshipit-source-id: ba0293e8b3595dfc94b3cf8048e50b7a5e5ed7cf
1 parent 21262e3 commit b602edc

File tree

2 files changed

+53
-15
lines changed

2 files changed

+53
-15
lines changed

pytorch3d/transforms/transform3d.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -390,16 +390,24 @@ def transform_normals(self, normals) -> torch.Tensor:
390390
return normals_out
391391

392392
def translate(self, *args, **kwargs) -> "Transform3d":
393-
return self.compose(Translate(device=self.device, *args, **kwargs))
393+
return self.compose(
394+
Translate(device=self.device, dtype=self.dtype, *args, **kwargs)
395+
)
394396

395397
def scale(self, *args, **kwargs) -> "Transform3d":
396-
return self.compose(Scale(device=self.device, *args, **kwargs))
398+
return self.compose(
399+
Scale(device=self.device, dtype=self.dtype, *args, **kwargs)
400+
)
397401

398402
def rotate(self, *args, **kwargs) -> "Transform3d":
399-
return self.compose(Rotate(device=self.device, *args, **kwargs))
403+
return self.compose(
404+
Rotate(device=self.device, dtype=self.dtype, *args, **kwargs)
405+
)
400406

401407
def rotate_axis_angle(self, *args, **kwargs) -> "Transform3d":
402-
return self.compose(RotateAxisAngle(device=self.device, *args, **kwargs))
408+
return self.compose(
409+
RotateAxisAngle(device=self.device, dtype=self.dtype, *args, **kwargs)
410+
)
403411

404412
def clone(self) -> "Transform3d":
405413
"""
@@ -488,7 +496,7 @@ def __init__(
488496
- A 1D torch tensor
489497
"""
490498
xyz = _handle_input(x, y, z, dtype, device, "Translate")
491-
super().__init__(device=xyz.device)
499+
super().__init__(device=xyz.device, dtype=dtype)
492500
N = xyz.shape[0]
493501

494502
mat = torch.eye(4, dtype=dtype, device=self.device)
@@ -532,7 +540,7 @@ def __init__(
532540
- 1D torch tensor
533541
"""
534542
xyz = _handle_input(x, y, z, dtype, device, "scale", allow_singleton=True)
535-
super().__init__(device=xyz.device)
543+
super().__init__(device=xyz.device, dtype=dtype)
536544
N = xyz.shape[0]
537545

538546
# TODO: Can we do this all in one go somehow?
@@ -571,7 +579,7 @@ def __init__(
571579
572580
"""
573581
device_ = get_device(R, device)
574-
super().__init__(device=device_)
582+
super().__init__(device=device_, dtype=dtype)
575583
if R.dim() == 2:
576584
R = R[None]
577585
if R.shape[-2:] != (3, 3):
@@ -598,7 +606,7 @@ def __init__(
598606
angle,
599607
axis: str = "X",
600608
degrees: bool = True,
601-
dtype: torch.dtype = torch.float64,
609+
dtype: torch.dtype = torch.float32,
602610
device: Optional[Device] = None,
603611
) -> None:
604612
"""
@@ -629,7 +637,7 @@ def __init__(
629637
# is for transforming column vectors. Therefore we transpose this matrix.
630638
# R will always be of shape (N, 3, 3)
631639
R = _axis_angle_rotation(axis, angle).transpose(1, 2)
632-
super().__init__(device=angle.device, R=R)
640+
super().__init__(device=angle.device, R=R, dtype=dtype)
633641

634642

635643
def _handle_coord(c, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
@@ -646,8 +654,8 @@ def _handle_coord(c, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
646654
c = torch.tensor(c, dtype=dtype, device=device)
647655
if c.dim() == 0:
648656
c = c.view(1)
649-
if c.device != device:
650-
c = c.to(device=device)
657+
if c.device != device or c.dtype != dtype:
658+
c = c.to(device=device, dtype=dtype)
651659
return c
652660

653661

@@ -696,7 +704,7 @@ def _handle_input(
696704
if y is not None or z is not None:
697705
msg = "Expected y and z to be None (in %s)" % name
698706
raise ValueError(msg)
699-
return x.to(device=device_)
707+
return x.to(device=device_, dtype=dtype)
700708

701709
if allow_singleton and y is None and z is None:
702710
y = x

tests/test_transforms.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,36 @@ def test_to(self):
8787
t = t.cuda()
8888
t = t.cpu()
8989

90+
def test_dtype_propagation(self):
91+
"""
92+
Check that a given dtype is correctly passed along to child
93+
transformations.
94+
"""
95+
# Use at least two dtypes so we avoid only testing on the
96+
# default dtype.
97+
for dtype in [torch.float32, torch.float64]:
98+
R = torch.tensor(
99+
[[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]],
100+
dtype=dtype,
101+
)
102+
tf = (
103+
Transform3d(dtype=dtype)
104+
.rotate(R)
105+
.rotate_axis_angle(
106+
R[0],
107+
"X",
108+
)
109+
.translate(3, 2, 1)
110+
.scale(0.5)
111+
)
112+
113+
self.assertEqual(tf.dtype, dtype)
114+
for inner_tf in tf._transforms:
115+
self.assertEqual(inner_tf.dtype, dtype)
116+
117+
transformed = tf.transform_points(R)
118+
self.assertEqual(transformed.dtype, dtype)
119+
90120
def test_clone(self):
91121
"""
92122
Check that cloned transformations contain different _matrix objects.
@@ -219,8 +249,8 @@ def test_rotate_axis_angle(self):
219249
normals_out_expected = torch.tensor(
220250
[[0.0, 1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0]]
221251
).view(1, 3, 3)
222-
self.assertTrue(torch.allclose(points_out, points_out_expected))
223-
self.assertTrue(torch.allclose(normals_out, normals_out_expected))
252+
self.assertTrue(torch.allclose(points_out, points_out_expected, atol=1e-7))
253+
self.assertTrue(torch.allclose(normals_out, normals_out_expected, atol=1e-7))
224254

225255
def test_transform_points_fail(self):
226256
t1 = Scale(0.1, 0.1, 0.1)
@@ -951,7 +981,7 @@ def test_rotate_x_python_scalar(self):
951981
self.assertTrue(
952982
torch.allclose(transformed_points.squeeze(), expected_points, atol=1e-7)
953983
)
954-
self.assertTrue(torch.allclose(t._matrix, matrix))
984+
self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7))
955985

956986
def test_rotate_x_torch_scalar(self):
957987
angle = torch.tensor(90.0)

0 commit comments

Comments
 (0)