88from typing import Tuple
99
1010import torch
11+ from pytorch3d .transforms import rotation_conversions
1112
1213from ..transforms import acos_linear_extrapolation
1314
@@ -160,19 +161,10 @@ def _so3_exp_map(
160161 nrms = (log_rot * log_rot ).sum (1 )
161162 # phis ... rotation angles
162163 rot_angles = torch .clamp (nrms , eps ).sqrt ()
163- # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
164- rot_angles_inv = 1.0 / rot_angles
165- fac1 = rot_angles_inv * rot_angles .sin ()
166- fac2 = rot_angles_inv * rot_angles_inv * (1.0 - rot_angles .cos ())
167164 skews = hat (log_rot )
168165 skews_square = torch .bmm (skews , skews )
169166
170- R = (
171- fac1 [:, None , None ] * skews
172- # pyre-fixme[16]: `float` has no attribute `__getitem__`.
173- + fac2 [:, None , None ] * skews_square
174- + torch .eye (3 , dtype = log_rot .dtype , device = log_rot .device )[None ]
175- )
167+ R = rotation_conversions .axis_angle_to_matrix (log_rot )
176168
177169 return R , rot_angles , skews , skews_square
178170
@@ -183,49 +175,23 @@ def so3_log_map(
183175 """
184176 Convert a batch of 3x3 rotation matrices `R`
185177 to a batch of 3-dimensional matrix logarithms of rotation matrices
186- The conversion has a singularity around `(R=I)` which is handled
187- by clamping controlled with the `eps` and `cos_bound` arguments.
178+ The conversion has a singularity around `(R=I)`.
188179
189180 Args:
190181 R: batch of rotation matrices of shape `(minibatch, 3, 3)`.
191- eps: A float constant handling the conversion singularity.
192- cos_bound: Clamps the cosine of the rotation angle to
193- [-1 + cos_bound, 1 - cos_bound] to avoid non-finite outputs/gradients
194- of the `acos` call when computing `so3_rotation_angle`.
195- Note that the non-finite outputs/gradients are returned when
196- the rotation angle is close to 0 or π.
182+ eps: (unused, for backward compatibility)
183+ cos_bound: (unused, for backward compatibility)
197184
198185 Returns:
199186 Batch of logarithms of input rotation matrices
200187 of shape `(minibatch, 3)`.
201-
202- Raises:
203- ValueError if `R` is of incorrect shape.
204- ValueError if `R` has an unexpected trace.
205188 """
206189
207190 N , dim1 , dim2 = R .shape
208191 if dim1 != 3 or dim2 != 3 :
209192 raise ValueError ("Input has to be a batch of 3x3 Tensors." )
210193
211- phi = so3_rotation_angle (R , cos_bound = cos_bound , eps = eps )
212-
213- phi_sin = torch .sin (phi )
214-
215- # We want to avoid a tiny denominator of phi_factor = phi / (2.0 * phi_sin).
216- # Hence, for phi_sin.abs() <= 0.5 * eps, we approximate phi_factor with
217- # 2nd order Taylor expansion: phi_factor = 0.5 + (1.0 / 12) * phi**2
218- phi_factor = torch .empty_like (phi )
219- ok_denom = phi_sin .abs () > (0.5 * eps )
220- # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
221- phi_factor [~ ok_denom ] = 0.5 + (phi [~ ok_denom ] ** 2 ) * (1.0 / 12 )
222- phi_factor [ok_denom ] = phi [ok_denom ] / (2.0 * phi_sin [ok_denom ])
223-
224- log_rot_hat = phi_factor [:, None , None ] * (R - R .permute (0 , 2 , 1 ))
225-
226- log_rot = hat_inv (log_rot_hat )
227-
228- return log_rot
194+ return rotation_conversions .matrix_to_axis_angle (R )
229195
230196
231197def hat_inv (h : torch .Tensor ) -> torch .Tensor :
0 commit comments