Skip to content

Commit c248b86

Browse files
committed
changes based on code-review
1 parent 9e15934 commit c248b86

File tree

4 files changed

+23
-13
lines changed

4 files changed

+23
-13
lines changed

test/test_transforms.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1179,14 +1179,14 @@ def test_adjust_gamma(self):
11791179
# test 1
11801180
y_pil = F.adjust_gamma(x_pil, 0.5)
11811181
y_np = np.array(y_pil)
1182-
y_ans = [0, 35, 57, 117, 185, 240, 97, 45, 244, 151, 255, 15]
1182+
y_ans = [0, 35, 57, 117, 186, 241, 97, 45, 245, 152, 255, 16]
11831183
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
11841184
self.assertTrue(np.allclose(y_np, y_ans))
11851185

11861186
# test 2
11871187
y_pil = F.adjust_gamma(x_pil, 2)
11881188
y_np = np.array(y_pil)
1189-
y_ans = [0, 0, 0, 11, 71, 200, 5, 0, 214, 31, 255, 0]
1189+
y_ans = [0, 0, 0, 11, 71, 201, 5, 0, 215, 31, 255, 0]
11901190
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
11911191
self.assertTrue(np.allclose(y_np, y_ans))
11921192

torchvision/transforms/functional.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -161,8 +161,14 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -
161161
msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely."
162162
raise RuntimeError(msg)
163163

164-
max = torch.iinfo(dtype).max
165-
return image.mul(torch.iinfo(dtype).max).clamp(0, max).to(dtype)
164+
# https://github.com/pytorch/vision/pull/2078#issuecomment-612045321
165+
# For data in the range 0-1, (float * 255).to(uint) is only 255
166+
# when float is exactly 1.0.
167+
# `max + 1 - epsilon` provides more evenly distributed mapping of
168+
# ranges of floats to ints.
169+
eps = 1e-3
170+
result = image.mul(torch.iinfo(dtype).max + 1 - eps)
171+
return result.to(dtype)
166172
else:
167173
# int to float
168174
if dtype.is_floating_point:
@@ -779,8 +785,10 @@ def adjust_gamma(img, gamma: float, gain: float = 1):
779785
gamma larger than 1 make the shadows darker,
780786
while gamma smaller than 1 make dark regions lighter.
781787
gain (float): The constant multiplier.
788+
Returns:
789+
PIL Image or Tensor: Gamma correction adjusted image.
782790
"""
783-
if F_pil._is_pil_image(img):
791+
if not isinstance(img, torch.Tensor):
784792
return F_pil.adjust_gamma(img, gamma, gain)
785793

786794
return F_t.adjust_gamma(img, gamma, gain)

torchvision/transforms/functional_pil.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,8 +193,7 @@ def adjust_gamma(img, gamma, gain=1):
193193

194194
input_mode = img.mode
195195
img = img.convert('RGB')
196-
197-
gamma_map = [255 * gain * pow(ele / 255., gamma) for ele in range(256)] * 3
196+
gamma_map = [(255 + 1 - 1e-3) * gain * pow(ele / 255., gamma) for ele in range(256)] * 3
198197
img = img.point(gamma_map) # use PIL's point-function to accelerate this part
199198

200199
img = img.convert(input_mode)

torchvision/transforms/functional_tensor.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
208208
.. _Gamma Correction: https://en.wikipedia.org/wiki/Gamma_correction
209209
210210
Args:
211-
img (Tensor): PIL Image to be adjusted.
211+
img (Tensor): Tensor of RBG values to be adjusted.
212212
gamma (float): Non negative real number, same as :math:`\gamma` in the equation.
213213
gamma larger than 1 make the shadows darker,
214214
while gamma smaller than 1 make dark regions lighter.
@@ -223,12 +223,15 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
223223

224224
result = img
225225
dtype = img.dtype
226-
if torch.is_floating_point(img):
227-
return gain * result ** gamma
226+
if not torch.is_floating_point(img):
227+
result = result / 255.0
228228

229-
result = 255.0 * gain * (result / 255.0) ** gamma
230-
# PIL clamps, to(torch.uint8) would wrap
231-
result = result.clamp(0, 255).to(dtype)
229+
result = (gain * result ** gamma).clamp(0, 1)
230+
231+
if result.dtype != dtype:
232+
eps = 1e-3
233+
result = (255 + 1.0 - eps) * result
234+
result = result.to(dtype)
232235
return result
233236

234237

0 commit comments

Comments
 (0)