Skip to content

Commit db1f7c4

Browse files
bottlerfacebook-github-bot
authored andcommitted
avoid symeig
Summary: Use the newer eigh to avoid deprecation warnings in newer pytorch. Reviewed By: patricklabatut Differential Revision: D34375784 fbshipit-source-id: 40efe0d33fdfa071fba80fc97ed008cbfd2ef249
1 parent 59972b1 commit db1f7c4

File tree

3 files changed

+17
-6
lines changed

3 files changed

+17
-6
lines changed

pytorch3d/common/compat.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,12 @@ def qr(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # pragma: no cove
4949
# PyTorch version >= 1.9
5050
return torch.linalg.qr(A)
5151
return torch.qr(A)
52+
53+
54+
def eigh(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # pragma: no cover
55+
"""
56+
Like torch.linalg.eigh, assuming the argument is a symmetric real matrix.
57+
"""
58+
if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"):
59+
return torch.linalg.eigh(A)
60+
return torch.symeig(A, eigenvalues=True)

pytorch3d/ops/perspective_n_points.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import torch
1818
import torch.nn.functional as F
19+
from pytorch3d.common.compat import eigh
1920
from pytorch3d.ops import points_alignment, utils as oputil
2021

2122

@@ -105,7 +106,7 @@ def _null_space(m, kernel_dim):
105106
kernel vectors, of size B x kernel_dim
106107
"""
107108
mTm = torch.bmm(m.transpose(1, 2), m)
108-
s, v = torch.symeig(mTm, eigenvectors=True)
109+
s, v = eigh(mTm)
109110
return v[:, :, :kernel_dim].reshape(-1, 4, 3, kernel_dim), s[:, :kernel_dim]
110111

111112

pytorch3d/ops/points_normals.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77
from typing import TYPE_CHECKING, Tuple, Union
88

99
import torch
10+
from pytorch3d.common.compat import eigh
11+
from pytorch3d.common.workaround import symeig3x3
1012

11-
from ..common.workaround import symeig3x3
1213
from .utils import convert_pointclouds_to_tensor, get_point_covariances
1314

1415

@@ -139,14 +140,14 @@ def estimate_pointcloud_local_coord_frames(
139140

140141
# get the local coord frames as principal directions of
141142
# the per-point covariance
142-
# this is done with torch.symeig, which returns the
143+
# this is done with torch.symeig / torch.linalg.eigh, which returns the
143144
# eigenvectors (=principal directions) in an ascending order of their
144-
# corresponding eigenvalues, while the smallest eigenvalue's eigenvector
145-
# corresponds to the normal direction
145+
# corresponding eigenvalues, and the smallest eigenvalue's eigenvector
146+
# corresponds to the normal direction; or with a custom equivalent.
146147
if use_symeig_workaround:
147148
curvatures, local_coord_frames = symeig3x3(cov, eigenvectors=True)
148149
else:
149-
curvatures, local_coord_frames = torch.symeig(cov, eigenvectors=True)
150+
curvatures, local_coord_frames = eigh(cov)
150151

151152
# disambiguate the directions of individual principal vectors
152153
if disambiguate_directions:

0 commit comments

Comments
 (0)