Skip to content

Commit c2862ff

Browse files
bottlerfacebook-github-bot
authored andcommitted
use workaround for points_normals
Summary: Use existing workaround for batched 3x3 symeig because it is faster than torch.symeig. Added benchmark showing speedup. True = workaround. ``` Benchmark Avg Time(μs) Peak Time(μs) Iterations -------------------------------------------------------------------------------- normals_True_3000 16237 17233 31 normals_True_6000 33028 33391 16 normals_False_3000 18623069 18623069 1 normals_False_6000 36535475 36535475 1 ``` Should help #988 Reviewed By: nikhilaravi Differential Revision: D33660585 fbshipit-source-id: d1162b277f5d61ed67e367057a61f25e03888dce
1 parent 5053142 commit c2862ff

File tree

2 files changed

+61
-1
lines changed

2 files changed

+61
-1
lines changed

pytorch3d/ops/points_normals.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import torch
1010

11+
from ..common.workaround import symeig3x3
1112
from .utils import convert_pointclouds_to_tensor, get_point_covariances
1213

1314

@@ -19,6 +20,8 @@ def estimate_pointcloud_normals(
1920
pointclouds: Union[torch.Tensor, "Pointclouds"],
2021
neighborhood_size: int = 50,
2122
disambiguate_directions: bool = True,
23+
*,
24+
use_symeig_workaround: bool = True,
2225
) -> torch.Tensor:
2326
"""
2427
Estimates the normals of a batch of `pointclouds`.
@@ -33,6 +36,8 @@ def estimate_pointcloud_normals(
3336
geometry around each point.
3437
**disambiguate_directions**: If `True`, uses the algorithm from [1] to
3538
ensure sign consistency of the normals of neighboring points.
39+
**use_symeig_workaround**: If `True`, uses a custom eigenvalue
40+
calculation.
3641
3742
Returns:
3843
**normals**: A tensor of normals for each input point
@@ -48,6 +53,7 @@ def estimate_pointcloud_normals(
4853
pointclouds,
4954
neighborhood_size=neighborhood_size,
5055
disambiguate_directions=disambiguate_directions,
56+
use_symeig_workaround=use_symeig_workaround,
5157
)
5258

5359
# the normals correspond to the first vector of each local coord frame
@@ -60,6 +66,8 @@ def estimate_pointcloud_local_coord_frames(
6066
pointclouds: Union[torch.Tensor, "Pointclouds"],
6167
neighborhood_size: int = 50,
6268
disambiguate_directions: bool = True,
69+
*,
70+
use_symeig_workaround: bool = True,
6371
) -> Tuple[torch.Tensor, torch.Tensor]:
6472
"""
6573
Estimates the principal directions of curvature (which includes normals)
@@ -88,6 +96,8 @@ def estimate_pointcloud_local_coord_frames(
8896
geometry around each point.
8997
**disambiguate_directions**: If `True`, uses the algorithm from [1] to
9098
ensure sign consistency of the normals of neighboring points.
99+
**use_symeig_workaround**: If `True`, uses a custom eigenvalue
100+
calculation.
91101
92102
Returns:
93103
**curvatures**: The three principal curvatures of each point
@@ -133,7 +143,10 @@ def estimate_pointcloud_local_coord_frames(
133143
# eigenvectors (=principal directions) in an ascending order of their
134144
# corresponding eigenvalues, while the smallest eigenvalue's eigenvector
135145
# corresponds to the normal direction
136-
curvatures, local_coord_frames = torch.symeig(cov, eigenvectors=True)
146+
if use_symeig_workaround:
147+
curvatures, local_coord_frames = symeig3x3(cov, eigenvectors=True)
148+
else:
149+
curvatures, local_coord_frames = torch.symeig(cov, eigenvectors=True)
137150

138151
# disambiguate the directions of individual principal vectors
139152
if disambiguate_directions:

tests/benchmarks/bm_points_normals.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import itertools
8+
9+
import torch
10+
from fvcore.common.benchmark import benchmark
11+
from pytorch3d.ops import estimate_pointcloud_normals
12+
from test_points_normals import TestPCLNormals
13+
14+
15+
def to_bm(num_points, use_symeig_workaround):
16+
device = torch.device("cuda:0")
17+
points_padded, _normals = TestPCLNormals.init_spherical_pcl(
18+
num_points=num_points, device=device, use_pointclouds=False
19+
)
20+
torch.cuda.synchronize()
21+
22+
def run():
23+
estimate_pointcloud_normals(
24+
points_padded, use_symeig_workaround=use_symeig_workaround
25+
)
26+
torch.cuda.synchronize()
27+
28+
return run
29+
30+
31+
def bm_points_normals() -> None:
32+
case_grid = {
33+
"use_symeig_workaround": [True, False],
34+
"num_points": [3000, 6000],
35+
}
36+
test_cases = itertools.product(*case_grid.values())
37+
kwargs_list = [dict(zip(case_grid.keys(), case)) for case in test_cases]
38+
benchmark(
39+
to_bm,
40+
"normals",
41+
kwargs_list,
42+
warmup_iters=1,
43+
)
44+
45+
46+
if __name__ == "__main__":
47+
bm_points_normals()

0 commit comments

Comments
 (0)