Skip to content

Commit dd4a35c

Browse files
nikhilaravifacebook-github-bot
authored andcommitted
mesh rasterizer settings fix
Summary: Fix default setting of `max_faces_per_bin` and update mesh rasterization benchmark tests. The previous setting of `max_faces_per_bin` was wrong and for larger mesh sizes and batch sizes it was causing a significant slow down due to an unecessarily large intermediate tensor being created. Reviewed By: gkioxari Differential Revision: D22301819 fbshipit-source-id: d5e817f5b917fb5633c9c6a8634b6c8ff65e3508
1 parent 88f5793 commit dd4a35c

File tree

3 files changed

+49
-30
lines changed

3 files changed

+49
-30
lines changed

pytorch3d/renderer/mesh/rasterize_meshes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def rasterize_meshes(
130130
)
131131

132132
if max_faces_per_bin is None:
133-
max_faces_per_bin = int(max(10000, verts_packed.shape[0] / 5))
133+
max_faces_per_bin = int(max(10000, meshes._F / 5))
134134

135135
# pyre-fixme[16]: `_RasterizeFaceVerts` has no attribute `apply`.
136136
return _RasterizeFaceVerts.apply(

tests/bm_rasterize_meshes.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
# 1: (42 verts, 80 faces)
1414
# 3: (642 verts, 1280 faces)
1515
# 4: (2562 verts, 5120 faces)
16+
# 5: (10242 verts, 20480 faces)
17+
# 6: (40962 verts, 81920 faces)
1618

1719

1820
def bm_rasterize_meshes() -> None:
@@ -22,6 +24,7 @@ def bm_rasterize_meshes() -> None:
2224
"ico_level": 0,
2325
"image_size": 10, # very slow with large image size
2426
"blur_radius": 0.0,
27+
"faces_per_pixel": 3,
2528
}
2629
]
2730
benchmark(
@@ -35,12 +38,19 @@ def bm_rasterize_meshes() -> None:
3538
num_meshes = [1]
3639
ico_level = [1]
3740
image_size = [64, 128]
38-
blur = [0.0, 1e-8, 1e-4]
39-
test_cases = product(num_meshes, ico_level, image_size, blur)
41+
blur = [1e-6]
42+
faces_per_pixel = [3, 50]
43+
test_cases = product(num_meshes, ico_level, image_size, blur, faces_per_pixel)
4044
for case in test_cases:
41-
n, ic, im, b = case
45+
n, ic, im, b, f = case
4246
kwargs_list.append(
43-
{"num_meshes": n, "ico_level": ic, "image_size": im, "blur_radius": b}
47+
{
48+
"num_meshes": n,
49+
"ico_level": ic,
50+
"image_size": im,
51+
"blur_radius": b,
52+
"faces_per_pixel": f,
53+
}
4454
)
4555
benchmark(
4656
TestRasterizeMeshes.rasterize_meshes_cpu_with_init,
@@ -51,26 +61,22 @@ def bm_rasterize_meshes() -> None:
5161

5262
if torch.cuda.is_available():
5363
kwargs_list = []
54-
num_meshes = [1, 8]
55-
ico_level = [0, 1, 3, 4]
64+
num_meshes = [8, 16]
65+
ico_level = [4, 5, 6]
5666
image_size = [64, 128, 512]
57-
blur = [0.0, 1e-8, 1e-4]
58-
bin_size = [0, 8, 32]
59-
test_cases = product(num_meshes, ico_level, image_size, blur, bin_size)
60-
# only keep cases where bin_size == 0 or image_size / bin_size < 16
61-
test_cases = [
62-
elem for elem in test_cases if (elem[-1] == 0 or elem[-3] / elem[-1] < 16)
63-
]
67+
blur = [1e-6]
68+
faces_per_pixel = [50]
69+
test_cases = product(num_meshes, ico_level, image_size, blur, faces_per_pixel)
70+
6471
for case in test_cases:
65-
n, ic, im, b, bn = case
72+
n, ic, im, b, f = case
6673
kwargs_list.append(
6774
{
6875
"num_meshes": n,
6976
"ico_level": ic,
7077
"image_size": im,
7178
"blur_radius": b,
72-
"bin_size": bn,
73-
"max_faces_per_bin": 200,
79+
"faces_per_pixel": f,
7480
}
7581
)
7682
benchmark(

tests/test_rasterize_meshes.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1004,26 +1004,42 @@ def _test_coarse_rasterize(self, device):
10041004

10051005
@staticmethod
10061006
def rasterize_meshes_python_with_init(
1007-
num_meshes: int, ico_level: int, image_size: int, blur_radius: float
1007+
num_meshes: int,
1008+
ico_level: int,
1009+
image_size: int,
1010+
blur_radius: float,
1011+
faces_per_pixel: int,
10081012
):
10091013
device = torch.device("cpu")
10101014
meshes = ico_sphere(ico_level, device)
10111015
meshes_batch = meshes.extend(num_meshes)
10121016

10131017
def rasterize():
1014-
rasterize_meshes_python(meshes_batch, image_size, blur_radius)
1018+
rasterize_meshes_python(
1019+
meshes_batch, image_size, blur_radius, faces_per_pixel
1020+
)
10151021

10161022
return rasterize
10171023

10181024
@staticmethod
10191025
def rasterize_meshes_cpu_with_init(
1020-
num_meshes: int, ico_level: int, image_size: int, blur_radius: float
1026+
num_meshes: int,
1027+
ico_level: int,
1028+
image_size: int,
1029+
blur_radius: float,
1030+
faces_per_pixel: int,
10211031
):
10221032
meshes = ico_sphere(ico_level, torch.device("cpu"))
10231033
meshes_batch = meshes.extend(num_meshes)
10241034

10251035
def rasterize():
1026-
rasterize_meshes(meshes_batch, image_size, blur_radius, bin_size=0)
1036+
rasterize_meshes(
1037+
meshes_batch,
1038+
image_size,
1039+
blur_radius,
1040+
faces_per_pixel=faces_per_pixel,
1041+
bin_size=0,
1042+
)
10271043

10281044
return rasterize
10291045

@@ -1033,18 +1049,15 @@ def rasterize_meshes_cuda_with_init(
10331049
ico_level: int,
10341050
image_size: int,
10351051
blur_radius: float,
1036-
bin_size: int,
1037-
max_faces_per_bin: int,
1052+
faces_per_pixel: int,
10381053
):
1039-
1040-
meshes = ico_sphere(ico_level, get_random_cuda_device())
1054+
device = get_random_cuda_device()
1055+
meshes = ico_sphere(ico_level, device)
10411056
meshes_batch = meshes.extend(num_meshes)
1042-
torch.cuda.synchronize()
1057+
torch.cuda.synchronize(device)
10431058

10441059
def rasterize():
1045-
rasterize_meshes(
1046-
meshes_batch, image_size, blur_radius, 8, bin_size, max_faces_per_bin
1047-
)
1048-
torch.cuda.synchronize()
1060+
rasterize_meshes(meshes_batch, image_size, blur_radius, faces_per_pixel)
1061+
torch.cuda.synchronize(device)
10491062

10501063
return rasterize

0 commit comments

Comments
 (0)