Skip to content

Commit ea5df60

Browse files
Krzysztof Chalupkafacebook-github-bot
Krzysztof Chalupka
authored andcommitted
In blending, pull common functionality into get_background_color
Summary: A small refactor, originally intended for use with the splatter. Reviewed By: bottler Differential Revision: D36210393 fbshipit-source-id: b3372f7cc7690ee45dd3059b2d4be1c8dfa63180
1 parent 4372001 commit ea5df60

File tree

1 file changed

+15
-15
lines changed

1 file changed

+15
-15
lines changed

pytorch3d/renderer/blending.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import torch
1111
from pytorch3d import _C
12+
from pytorch3d.common.datatypes import Device
1213

1314

1415
# Example functions for blending the top K colors per pixel using the outputs
@@ -37,6 +38,17 @@ class BlendParams(NamedTuple):
3738
background_color: Union[torch.Tensor, Sequence[float]] = (1.0, 1.0, 1.0)
3839

3940

41+
def _get_background_color(
42+
blend_params: BlendParams, device: Device, dtype=torch.float32
43+
) -> torch.Tensor:
44+
background_color_ = blend_params.background_color
45+
if isinstance(background_color_, torch.Tensor):
46+
background_color = background_color_.to(device)
47+
else:
48+
background_color = torch.tensor(background_color_, dtype=dtype, device=device)
49+
return background_color
50+
51+
4052
def hard_rgb_blend(
4153
colors: torch.Tensor, fragments, blend_params: BlendParams
4254
) -> torch.Tensor:
@@ -57,18 +69,11 @@ def hard_rgb_blend(
5769
Returns:
5870
RGBA pixel_colors: (N, H, W, 4)
5971
"""
60-
N, H, W, K = fragments.pix_to_face.shape
61-
device = fragments.pix_to_face.device
72+
background_color = _get_background_color(blend_params, fragments.pix_to_face.device)
6273

6374
# Mask for the background.
6475
is_background = fragments.pix_to_face[..., 0] < 0 # (N, H, W)
6576

66-
background_color_ = blend_params.background_color
67-
if isinstance(background_color_, torch.Tensor):
68-
background_color = background_color_.to(device)
69-
else:
70-
background_color = colors.new_tensor(background_color_)
71-
7277
# Find out how much background_color needs to be expanded to be used for masked_scatter.
7378
num_background_pixels = is_background.sum()
7479

@@ -182,13 +187,8 @@ def softmax_rgb_blend(
182187
"""
183188

184189
N, H, W, K = fragments.pix_to_face.shape
185-
device = fragments.pix_to_face.device
186190
pixel_colors = torch.ones((N, H, W, 4), dtype=colors.dtype, device=colors.device)
187-
background_ = blend_params.background_color
188-
if not isinstance(background_, torch.Tensor):
189-
background = torch.tensor(background_, dtype=torch.float32, device=device)
190-
else:
191-
background = background_.to(device)
191+
background_color = _get_background_color(blend_params, fragments.pix_to_face.device)
192192

193193
# Weight for background color
194194
eps = 1e-10
@@ -233,7 +233,7 @@ def softmax_rgb_blend(
233233

234234
# Sum: weights * textures + background color
235235
weighted_colors = (weights_num[..., None] * colors).sum(dim=-2)
236-
weighted_background = delta * background
236+
weighted_background = delta * background_color
237237
pixel_colors[..., :3] = (weighted_colors + weighted_background) / denom
238238
pixel_colors[..., 3] = 1.0 - alpha
239239

0 commit comments

Comments
 (0)