Skip to content

Commit 5852b74

Browse files
nikhilaravifacebook-github-bot
authored andcommitted
Softmax blending small fix
Summary: Small fix to the softmax blending function. To avoid overflow in the exponential for the softmax, the exponent is shifted by the maximum value. In the final calculation of the color there is a weighted sum between the pixel color and the background color - in order for the sum to be correct, the background color also needs to be handled in the same way witt the shifted exponent. Reviewed By: gkioxari Differential Revision: D23148301 fbshipit-source-id: 86066586ee7d3ce7bd4a2076b12ce191fbd151a7
1 parent 8e9ff15 commit 5852b74

File tree

2 files changed

+18
-18
lines changed

2 files changed

+18
-18
lines changed

pytorch3d/renderer/blending.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
from typing import NamedTuple, Sequence
55

6-
import numpy as np
76
import torch
87

98
# pyre-fixme[21]: Could not find name `_C` in `pytorch3d`.
@@ -162,9 +161,8 @@ def softmax_rgb_blend(
162161
if not torch.is_tensor(background):
163162
background = torch.tensor(background, dtype=torch.float32, device=device)
164163

165-
# Background color
166-
delta = np.exp(1e-10 / blend_params.gamma) * 1e-10
167-
delta = torch.tensor(delta, device=device)
164+
# Weight for background color
165+
eps = 1e-10
168166

169167
# Mask for padded pixels.
170168
mask = fragments.pix_to_face >= 0
@@ -189,15 +187,18 @@ def softmax_rgb_blend(
189187
# pyre-fixme[6]: Expected `Tensor` for 1st param but got `float`.
190188
weights_num = prob_map * torch.exp((z_inv - z_inv_max) / blend_params.gamma)
191189

190+
# Also apply exp normalize trick for the background color weight.
191+
# Clamp to ensure delta is never 0.
192+
delta = torch.exp((eps - z_inv_max) / blend_params.gamma).clamp(min=eps)
193+
192194
# Normalize weights.
193195
# weights_num shape: (N, H, W, K). Sum over K and divide through by the sum.
194196
denom = weights_num.sum(dim=-1)[..., None] + delta
195-
weights = weights_num / denom
196197

197198
# Sum: weights * textures + background color
198-
weighted_colors = (weights[..., None] * colors).sum(dim=-2)
199-
weighted_background = (delta / denom) * background
200-
pixel_colors[..., :3] = weighted_colors + weighted_background
199+
weighted_colors = (weights_num[..., None] * colors).sum(dim=-2)
200+
weighted_background = delta * background
201+
pixel_colors[..., :3] = (weighted_colors + weighted_background) / denom
201202
pixel_colors[..., 3] = 1.0 - alpha
202203

203204
return pixel_colors

tests/test_blending.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import unittest
44

5-
import numpy as np
65
import torch
76
from common_testing import TestCaseMixin
87
from pytorch3d.renderer.blending import (
@@ -97,21 +96,18 @@ def softmax_blend_naive(colors, fragments, blend_params):
9796
# Near and far clipping planes
9897
zfar = 100.0
9998
znear = 1.0
99+
eps = 1e-10
100100

101101
bk_color = blend_params.background_color
102102
if not torch.is_tensor(bk_color):
103103
bk_color = torch.tensor(bk_color, dtype=colors.dtype, device=device)
104104

105-
# Background color component
106-
delta = np.exp(1e-10 / gamma) * 1e-10
107-
delta = torch.tensor(delta).to(device=device)
108-
109105
for n in range(N):
110106
for h in range(H):
111107
for w in range(W):
112108
alpha = 1.0
113109
weights_k = torch.zeros(K, device=device)
114-
zmax = 0.0
110+
zmax = torch.tensor(0.0, device=device)
115111

116112
# Loop over K to find max z.
117113
for k in range(K):
@@ -129,11 +125,13 @@ def softmax_blend_naive(colors, fragments, blend_params):
129125
alpha *= 1.0 - prob # cumulative product
130126
weights_k[k] = prob * torch.exp((zinv - zmax) / gamma)
131127

128+
# Clamp to ensure delta is never 0
129+
delta = torch.exp((eps - zmax) / blend_params.gamma).clamp(min=eps)
130+
delta = delta.to(device)
132131
denom = weights_k.sum() + delta
133-
weights = weights_k / denom
134-
cols = (weights[..., None] * colors[n, h, w, :, :]).sum(dim=0)
135-
pixel_colors[n, h, w, :3] = cols
136-
pixel_colors[n, h, w, :3] += (delta / denom) * bk_color
132+
cols = (weights_k[..., None] * colors[n, h, w, :, :]).sum(dim=0)
133+
pixel_colors[n, h, w, :3] = cols + delta * bk_color
134+
pixel_colors[n, h, w, :3] /= denom
137135
pixel_colors[n, h, w, 3] = 1.0 - alpha
138136

139137
return pixel_colors
@@ -160,6 +158,7 @@ def _compare_impls(
160158

161159
(out2 * grad_out).sum().backward()
162160
self.assertTrue(hasattr(grad_var2, "grad"))
161+
163162
self.assertClose(grad_var1.grad.cpu(), grad_var2.grad.cpu(), atol=2e-5)
164163

165164
def test_hard_rgb_blend(self):

0 commit comments

Comments
 (0)