9
9
10
10
import torch
11
11
from pytorch3d import _C
12
+ from pytorch3d .common .datatypes import Device
12
13
13
14
14
15
# Example functions for blending the top K colors per pixel using the outputs
@@ -37,6 +38,17 @@ class BlendParams(NamedTuple):
37
38
background_color : Union [torch .Tensor , Sequence [float ]] = (1.0 , 1.0 , 1.0 )
38
39
39
40
41
+ def _get_background_color (
42
+ blend_params : BlendParams , device : Device , dtype = torch .float32
43
+ ) -> torch .Tensor :
44
+ background_color_ = blend_params .background_color
45
+ if isinstance (background_color_ , torch .Tensor ):
46
+ background_color = background_color_ .to (device )
47
+ else :
48
+ background_color = torch .tensor (background_color_ , dtype = dtype , device = device )
49
+ return background_color
50
+
51
+
40
52
def hard_rgb_blend (
41
53
colors : torch .Tensor , fragments , blend_params : BlendParams
42
54
) -> torch .Tensor :
@@ -57,18 +69,11 @@ def hard_rgb_blend(
57
69
Returns:
58
70
RGBA pixel_colors: (N, H, W, 4)
59
71
"""
60
- N , H , W , K = fragments .pix_to_face .shape
61
- device = fragments .pix_to_face .device
72
+ background_color = _get_background_color (blend_params , fragments .pix_to_face .device )
62
73
63
74
# Mask for the background.
64
75
is_background = fragments .pix_to_face [..., 0 ] < 0 # (N, H, W)
65
76
66
- background_color_ = blend_params .background_color
67
- if isinstance (background_color_ , torch .Tensor ):
68
- background_color = background_color_ .to (device )
69
- else :
70
- background_color = colors .new_tensor (background_color_ )
71
-
72
77
# Find out how much background_color needs to be expanded to be used for masked_scatter.
73
78
num_background_pixels = is_background .sum ()
74
79
@@ -182,13 +187,8 @@ def softmax_rgb_blend(
182
187
"""
183
188
184
189
N , H , W , K = fragments .pix_to_face .shape
185
- device = fragments .pix_to_face .device
186
190
pixel_colors = torch .ones ((N , H , W , 4 ), dtype = colors .dtype , device = colors .device )
187
- background_ = blend_params .background_color
188
- if not isinstance (background_ , torch .Tensor ):
189
- background = torch .tensor (background_ , dtype = torch .float32 , device = device )
190
- else :
191
- background = background_ .to (device )
191
+ background_color = _get_background_color (blend_params , fragments .pix_to_face .device )
192
192
193
193
# Weight for background color
194
194
eps = 1e-10
@@ -233,7 +233,7 @@ def softmax_rgb_blend(
233
233
234
234
# Sum: weights * textures + background color
235
235
weighted_colors = (weights_num [..., None ] * colors ).sum (dim = - 2 )
236
- weighted_background = delta * background
236
+ weighted_background = delta * background_color
237
237
pixel_colors [..., :3 ] = (weighted_colors + weighted_background ) / denom
238
238
pixel_colors [..., 3 ] = 1.0 - alpha
239
239
0 commit comments