1
1
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2
2
3
3
4
+ import math
4
5
import unittest
5
6
6
7
import numpy as np
7
8
import torch
9
+ from common_testing import TestCaseMixin
8
10
from pytorch3d .transforms .so3 import (
9
11
hat ,
10
12
so3_exponential_map ,
13
15
)
14
16
15
17
16
- class TestSO3 (unittest .TestCase ):
18
+ class TestSO3 (TestCaseMixin , unittest .TestCase ):
17
19
def setUp (self ) -> None :
18
20
super ().setUp ()
19
21
torch .manual_seed (42 )
@@ -55,9 +57,8 @@ def test_determinant(self):
55
57
"""
56
58
log_rot = TestSO3 .init_log_rot (batch_size = 30 )
57
59
Rs = so3_exponential_map (log_rot )
58
- for R in Rs :
59
- det = np .linalg .det (R .cpu ().numpy ())
60
- self .assertAlmostEqual (float (det ), 1.0 , 5 )
60
+ dets = torch .det (Rs )
61
+ self .assertClose (dets , torch .ones_like (dets ), atol = 1e-4 )
61
62
62
63
def test_cross (self ):
63
64
"""
@@ -70,8 +71,7 @@ def test_cross(self):
70
71
hat_a = hat (a )
71
72
cross = torch .bmm (hat_a , b [:, :, None ])[:, :, 0 ]
72
73
torch_cross = torch .cross (a , b , dim = 1 )
73
- max_df = (cross - torch_cross ).abs ().max ()
74
- self .assertAlmostEqual (float (max_df ), 0.0 , 5 )
74
+ self .assertClose (torch_cross , cross , atol = 1e-4 )
75
75
76
76
def test_bad_so3_input_value_err (self ):
77
77
"""
@@ -126,37 +126,63 @@ def test_so3_log_singularity(self, batch_size: int = 100):
126
126
"""
127
127
# generate random rotations with a tiny angle
128
128
device = torch .device ("cuda:0" )
129
- r = torch .eye (3 , device = device )[None ].repeat ((batch_size , 1 , 1 ))
130
- r += torch .randn ((batch_size , 3 , 3 ), device = device ) * 1e-3
131
- r = torch .stack ([torch .qr (r_ )[0 ] for r_ in r ])
129
+ identity = torch .eye (3 , device = device )
130
+ rot180 = identity * torch .tensor ([[1.0 , - 1.0 , - 1.0 ]], device = device )
131
+ r = [identity , rot180 ]
132
+ r .extend (
133
+ [
134
+ torch .qr (identity + torch .randn_like (identity ) * 1e-4 )[0 ]
135
+ for _ in range (batch_size - 2 )
136
+ ]
137
+ )
138
+ r = torch .stack (r )
132
139
# the log of the rotation matrix r
133
140
r_log = so3_log_map (r )
134
141
# tests whether all outputs are finite
135
142
r_sum = float (r_log .sum ())
136
143
self .assertEqual (r_sum , r_sum )
137
144
145
+ def test_so3_log_to_exp_to_log_to_exp (self , batch_size : int = 100 ):
146
+ """
147
+ Check that
148
+ `so3_exponential_map(so3_log_map(so3_exponential_map(log_rot)))
149
+ == so3_exponential_map(log_rot)`
150
+ for a randomly generated batch of rotation matrix logarithms `log_rot`.
151
+ Unlike `test_so3_log_to_exp_to_log`, this test allows to check the
152
+ correctness of converting `log_rot` which contains values > math.pi.
153
+ """
154
+ log_rot = 2.0 * TestSO3 .init_log_rot (batch_size = batch_size )
155
+ # check also the singular cases where rot. angle = {0, pi, 2pi, 3pi}
156
+ log_rot [:3 ] = 0
157
+ log_rot [1 , 0 ] = math .pi
158
+ log_rot [2 , 0 ] = 2.0 * math .pi
159
+ log_rot [3 , 0 ] = 3.0 * math .pi
160
+ rot = so3_exponential_map (log_rot , eps = 1e-8 )
161
+ rot_ = so3_exponential_map (so3_log_map (rot , eps = 1e-8 ), eps = 1e-8 )
162
+ angles = so3_relative_angle (rot , rot_ )
163
+ self .assertClose (angles , torch .zeros_like (angles ), atol = 0.01 )
164
+
138
165
def test_so3_log_to_exp_to_log (self , batch_size : int = 100 ):
139
166
"""
140
167
Check that `so3_log_map(so3_exponential_map(log_rot))==log_rot` for
141
168
a randomly generated batch of rotation matrix logarithms `log_rot`.
142
169
"""
143
170
log_rot = TestSO3 .init_log_rot (batch_size = batch_size )
171
+ # check also the singular cases where rot. angle = 0
172
+ log_rot [:1 ] = 0
144
173
log_rot_ = so3_log_map (so3_exponential_map (log_rot ))
145
- max_df = (log_rot - log_rot_ ).abs ().max ()
146
- self .assertAlmostEqual (float (max_df ), 0.0 , 4 )
174
+ self .assertClose (log_rot , log_rot_ , atol = 1e-4 )
147
175
148
176
def test_so3_exp_to_log_to_exp (self , batch_size : int = 100 ):
149
177
"""
150
178
Check that `so3_exponential_map(so3_log_map(R))==R` for
151
179
a batch of randomly generated rotation matrices `R`.
152
180
"""
153
181
rot = TestSO3 .init_rot (batch_size = batch_size )
154
- rot_ = so3_exponential_map (so3_log_map (rot ) )
182
+ rot_ = so3_exponential_map (so3_log_map (rot , eps = 1e-8 ), eps = 1e-8 )
155
183
angles = so3_relative_angle (rot , rot_ )
156
- max_angle = angles .max ()
157
- # a lot of precision lost here :(
158
- # TODO: fix this test??
159
- self .assertTrue (np .allclose (float (max_angle ), 0.0 , atol = 0.1 ))
184
+ # TODO: a lot of precision lost here ...
185
+ self .assertClose (angles , torch .zeros_like (angles ), atol = 0.1 )
160
186
161
187
def test_so3_cos_angle (self , batch_size : int = 100 ):
162
188
"""
@@ -168,7 +194,7 @@ def test_so3_cos_angle(self, batch_size: int = 100):
168
194
rot2 = TestSO3 .init_rot (batch_size = batch_size )
169
195
angles = so3_relative_angle (rot1 , rot2 , cos_angle = False ).cos ()
170
196
angles_ = so3_relative_angle (rot1 , rot2 , cos_angle = True )
171
- self .assertTrue ( torch . allclose ( angles , angles_ ) )
197
+ self .assertClose ( angles , angles_ )
172
198
173
199
@staticmethod
174
200
def so3_expmap (batch_size : int = 10 ):
0 commit comments