Skip to content

Commit 34a0df0

Browse files
davnov134facebook-github-bot
authored andcommitted
SO3 log map fix for singularity at PI
Summary: Fixes the case where the rotation angle is exactly 0/PI. Added a test for `so3_log_map(identity_matrix)`. Reviewed By: nikhilaravi Differential Revision: D21477078 fbshipit-source-id: adff804da97f6f0d4f50aa1f6904a34832cb8bfe
1 parent 17ca6ec commit 34a0df0

File tree

2 files changed

+49
-20
lines changed

2 files changed

+49
-20
lines changed

pytorch3d/transforms/so3.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -152,11 +152,14 @@ def so3_log_map(R, eps: float = 0.0001):
152152

153153
phi = so3_rotation_angle(R)
154154

155-
phi_valid = torch.clamp(phi.abs(), eps) * phi.sign()
155+
phi_sin = phi.sin()
156156

157-
log_rot_hat = (phi_valid / (2.0 * phi_valid.sin()))[:, None, None] * (
158-
R - R.permute(0, 2, 1)
157+
phi_denom = (
158+
torch.clamp(phi_sin.abs(), eps) * phi_sin.sign()
159+
+ (phi_sin == 0).type_as(phi) * eps
159160
)
161+
162+
log_rot_hat = (phi / (2.0 * phi_denom))[:, None, None] * (R - R.permute(0, 2, 1))
160163
log_rot = hat_inv(log_rot_hat)
161164

162165
return log_rot

tests/test_so3.py

Lines changed: 43 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
22

33

4+
import math
45
import unittest
56

67
import numpy as np
78
import torch
9+
from common_testing import TestCaseMixin
810
from pytorch3d.transforms.so3 import (
911
hat,
1012
so3_exponential_map,
@@ -13,7 +15,7 @@
1315
)
1416

1517

16-
class TestSO3(unittest.TestCase):
18+
class TestSO3(TestCaseMixin, unittest.TestCase):
1719
def setUp(self) -> None:
1820
super().setUp()
1921
torch.manual_seed(42)
@@ -55,9 +57,8 @@ def test_determinant(self):
5557
"""
5658
log_rot = TestSO3.init_log_rot(batch_size=30)
5759
Rs = so3_exponential_map(log_rot)
58-
for R in Rs:
59-
det = np.linalg.det(R.cpu().numpy())
60-
self.assertAlmostEqual(float(det), 1.0, 5)
60+
dets = torch.det(Rs)
61+
self.assertClose(dets, torch.ones_like(dets), atol=1e-4)
6162

6263
def test_cross(self):
6364
"""
@@ -70,8 +71,7 @@ def test_cross(self):
7071
hat_a = hat(a)
7172
cross = torch.bmm(hat_a, b[:, :, None])[:, :, 0]
7273
torch_cross = torch.cross(a, b, dim=1)
73-
max_df = (cross - torch_cross).abs().max()
74-
self.assertAlmostEqual(float(max_df), 0.0, 5)
74+
self.assertClose(torch_cross, cross, atol=1e-4)
7575

7676
def test_bad_so3_input_value_err(self):
7777
"""
@@ -126,37 +126,63 @@ def test_so3_log_singularity(self, batch_size: int = 100):
126126
"""
127127
# generate random rotations with a tiny angle
128128
device = torch.device("cuda:0")
129-
r = torch.eye(3, device=device)[None].repeat((batch_size, 1, 1))
130-
r += torch.randn((batch_size, 3, 3), device=device) * 1e-3
131-
r = torch.stack([torch.qr(r_)[0] for r_ in r])
129+
identity = torch.eye(3, device=device)
130+
rot180 = identity * torch.tensor([[1.0, -1.0, -1.0]], device=device)
131+
r = [identity, rot180]
132+
r.extend(
133+
[
134+
torch.qr(identity + torch.randn_like(identity) * 1e-4)[0]
135+
for _ in range(batch_size - 2)
136+
]
137+
)
138+
r = torch.stack(r)
132139
# the log of the rotation matrix r
133140
r_log = so3_log_map(r)
134141
# tests whether all outputs are finite
135142
r_sum = float(r_log.sum())
136143
self.assertEqual(r_sum, r_sum)
137144

145+
def test_so3_log_to_exp_to_log_to_exp(self, batch_size: int = 100):
146+
"""
147+
Check that
148+
`so3_exponential_map(so3_log_map(so3_exponential_map(log_rot)))
149+
== so3_exponential_map(log_rot)`
150+
for a randomly generated batch of rotation matrix logarithms `log_rot`.
151+
Unlike `test_so3_log_to_exp_to_log`, this test allows to check the
152+
correctness of converting `log_rot` which contains values > math.pi.
153+
"""
154+
log_rot = 2.0 * TestSO3.init_log_rot(batch_size=batch_size)
155+
# check also the singular cases where rot. angle = {0, pi, 2pi, 3pi}
156+
log_rot[:3] = 0
157+
log_rot[1, 0] = math.pi
158+
log_rot[2, 0] = 2.0 * math.pi
159+
log_rot[3, 0] = 3.0 * math.pi
160+
rot = so3_exponential_map(log_rot, eps=1e-8)
161+
rot_ = so3_exponential_map(so3_log_map(rot, eps=1e-8), eps=1e-8)
162+
angles = so3_relative_angle(rot, rot_)
163+
self.assertClose(angles, torch.zeros_like(angles), atol=0.01)
164+
138165
def test_so3_log_to_exp_to_log(self, batch_size: int = 100):
139166
"""
140167
Check that `so3_log_map(so3_exponential_map(log_rot))==log_rot` for
141168
a randomly generated batch of rotation matrix logarithms `log_rot`.
142169
"""
143170
log_rot = TestSO3.init_log_rot(batch_size=batch_size)
171+
# check also the singular cases where rot. angle = 0
172+
log_rot[:1] = 0
144173
log_rot_ = so3_log_map(so3_exponential_map(log_rot))
145-
max_df = (log_rot - log_rot_).abs().max()
146-
self.assertAlmostEqual(float(max_df), 0.0, 4)
174+
self.assertClose(log_rot, log_rot_, atol=1e-4)
147175

148176
def test_so3_exp_to_log_to_exp(self, batch_size: int = 100):
149177
"""
150178
Check that `so3_exponential_map(so3_log_map(R))==R` for
151179
a batch of randomly generated rotation matrices `R`.
152180
"""
153181
rot = TestSO3.init_rot(batch_size=batch_size)
154-
rot_ = so3_exponential_map(so3_log_map(rot))
182+
rot_ = so3_exponential_map(so3_log_map(rot, eps=1e-8), eps=1e-8)
155183
angles = so3_relative_angle(rot, rot_)
156-
max_angle = angles.max()
157-
# a lot of precision lost here :(
158-
# TODO: fix this test??
159-
self.assertTrue(np.allclose(float(max_angle), 0.0, atol=0.1))
184+
# TODO: a lot of precision lost here ...
185+
self.assertClose(angles, torch.zeros_like(angles), atol=0.1)
160186

161187
def test_so3_cos_angle(self, batch_size: int = 100):
162188
"""
@@ -168,7 +194,7 @@ def test_so3_cos_angle(self, batch_size: int = 100):
168194
rot2 = TestSO3.init_rot(batch_size=batch_size)
169195
angles = so3_relative_angle(rot1, rot2, cos_angle=False).cos()
170196
angles_ = so3_relative_angle(rot1, rot2, cos_angle=True)
171-
self.assertTrue(torch.allclose(angles, angles_))
197+
self.assertClose(angles, angles_)
172198

173199
@staticmethod
174200
def so3_expmap(batch_size: int = 10):

0 commit comments

Comments
 (0)