Skip to content

Commit 5f702bf

Browse files
committed
add torchscriptable adjust_gamma transform
#1375
1 parent 892d0ef commit 5f702bf

File tree

4 files changed

+111
-17
lines changed

4 files changed

+111
-17
lines changed

test/test_functional_tensor.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ def _create_data(self, height=3, width=3, channels=3):
2323

2424
def compareTensorToPIL(self, tensor, pil_image, msg=None):
    """Assert that ``tensor`` is exactly equal to the tensor form of ``pil_image``.

    ``pil_image`` is turned into an array, its last axis is moved to the
    front with ``transpose((2, 0, 1))``, and the result is compared
    element-wise with ``tensor``.

    Args:
        tensor: Tensor to check.
        pil_image: Image converted via ``np.array`` for the comparison.
        msg: Optional failure message. When omitted, a message showing both
            tensors is generated.
    """
    reference = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1)))
    failure_msg = (
        msg
        if msg is not None
        else "tensor:\n{} \ndid not equal PIL tensor:\n{}".format(tensor, reference)
    )
    self.assertTrue(tensor.equal(reference), failure_msg)
2729

2830
def approxEqualTensorToPIL(self, tensor, pil_image, tol=1e-5, msg=None):
@@ -293,6 +295,33 @@ def test_pad(self):
293295
with self.assertRaises(ValueError, msg="Padding can not be negative for symmetric padding_mode"):
294296
F_t.pad(tensor, (-2, -3), padding_mode="symmetric")
295297

298+
def test_adjust_gamma(self):
    """Check ``F_t.adjust_gamma`` against its scripted form and the PIL version.

    For each dtype pass and each (gamma, gain) pair the test verifies that
    the eager and scripted results share dtype and values, that the spatial
    size matches the PIL output, and that the result (converted to uint8
    when needed) equals the PIL output.
    """
    scripted_adjust_gamma = torch.jit.script(F_t.adjust_gamma)
    tensor, pil_img = self._create_data(26, 36)
    gamma_gain_pairs = [(0.8, 0.7), (1.0, 1.0), (1.2, 1.3)]

    for dt in (torch.float64, torch.float32, None):
        # NOTE(review): `tensor` is reassigned here, so each conversion is
        # applied to the previous pass's tensor and the final `None` pass
        # reuses the last converted tensor instead of the one returned by
        # _create_data -- confirm this is intended.
        if dt is not None:
            tensor = F.convert_image_dtype(tensor, dt)

        for gamma, gain in gamma_gain_pairs:
            eager_out = F_t.adjust_gamma(tensor, gamma, gain)
            pil_out = F_pil.adjust_gamma(pil_img, gamma, gain)
            scripted_out = scripted_adjust_gamma(tensor, gamma, gain)

            self.assertEqual(eager_out.dtype, scripted_out.dtype)
            self.assertEqual(eager_out.size()[1:], pil_out.size[::-1])

            # The PIL comparison is done on uint8 data.
            if eager_out.dtype == torch.uint8:
                rgb_tensor = eager_out
            else:
                rgb_tensor = F.convert_image_dtype(eager_out, torch.uint8)
            self.compareTensorToPIL(rgb_tensor, pil_out)

            self.assertTrue(eager_out.equal(scripted_out))
296325
def test_resize(self):
297326
script_fn = torch.jit.script(F_t.resize)
298327
tensor, pil_img = self._create_data(26, 36)

torchvision/transforms/functional.py

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -161,8 +161,8 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -
161161
msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely."
162162
raise RuntimeError(msg)
163163

164-
eps = 1e-3
165-
return image.mul(torch.iinfo(dtype).max + 1 - eps).to(dtype)
164+
max = torch.iinfo(dtype).max
165+
return image.mul(torch.iinfo(dtype).max).clamp(0, max).to(dtype)
166166
else:
167167
# int to float
168168
if dtype.is_floating_point:
@@ -760,7 +760,7 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
760760
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
761761

762762

763-
def adjust_gamma(img, gamma, gain=1):
763+
def adjust_gamma(img, gamma: float, gain: float = 1):
764764
r"""Perform gamma correction on an image.
765765
766766
Also known as Power Law Transform. Intensities in RGB mode are adjusted
@@ -774,26 +774,16 @@ def adjust_gamma(img, gamma, gain=1):
774774
.. _Gamma Correction: https://en.wikipedia.org/wiki/Gamma_correction
775775
776776
Args:
777-
img (PIL Image): PIL Image to be adjusted.
777+
img (PIL Image or Tensor): PIL Image to be adjusted.
778778
gamma (float): Non negative real number, same as :math:`\gamma` in the equation.
779779
gamma larger than 1 make the shadows darker,
780780
while gamma smaller than 1 make dark regions lighter.
781781
gain (float): The constant multiplier.
782782
"""
783-
if not F_pil._is_pil_image(img):
784-
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
785-
786-
if gamma < 0:
787-
raise ValueError('Gamma should be a non-negative real number')
783+
if F_pil._is_pil_image(img):
784+
return F_pil.adjust_gamma(img, gamma, gain)
788785

789-
input_mode = img.mode
790-
img = img.convert('RGB')
791-
792-
gamma_map = [255 * gain * pow(ele / 255., gamma) for ele in range(256)] * 3
793-
img = img.point(gamma_map) # use PIL's point-function to accelerate this part
794-
795-
img = img.convert(input_mode)
796-
return img
786+
return F_t.adjust_gamma(img, gamma, gain)
797787

798788

799789
def rotate(img, angle, resample=False, expand=False, center=None, fill=None):

torchvision/transforms/functional_pil.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,43 @@ def adjust_hue(img, hue_factor):
164164
return img
165165

166166

167+
@torch.jit.unused
def adjust_gamma(img, gamma, gain=1):
    r"""Apply gamma correction (Power Law Transform) to a PIL image.

    The image is converted to RGB, every band is remapped through a
    256-entry lookup table built from

    .. math::
        I_{\text{out}} = 255 \times \text{gain} \times \left(\frac{I_{\text{in}}}{255}\right)^{\gamma}

    and the result is converted back to the image's original mode.

    See `Gamma Correction`_ for more details.

    .. _Gamma Correction: https://en.wikipedia.org/wiki/Gamma_correction

    Args:
        img (PIL Image): PIL Image to be adjusted.
        gamma (float): Non negative real number, same as :math:`\gamma` in the equation.
            gamma larger than 1 make the shadows darker,
            while gamma smaller than 1 make dark regions lighter.
        gain (float): The constant multiplier.

    Returns:
        PIL Image: The adjusted image, in the same mode as ``img``.

    Raises:
        TypeError: If ``img`` is not a PIL Image.
        ValueError: If ``gamma`` is negative.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    if gamma < 0:
        raise ValueError('Gamma should be a non-negative real number')

    original_mode = img.mode

    # One lookup table per band; the same table is reused for R, G and B.
    band_lut = [255 * gain * pow(level / 255., gamma) for level in range(256)]

    # PIL's point-function applies the table far faster than a Python loop.
    corrected = img.convert('RGB').point(band_lut * 3)
    return corrected.convert(original_mode)
202+
203+
167204
@torch.jit.unused
168205
def pad(img, padding, fill=0, padding_mode="constant"):
169206
r"""Pad the given PIL.Image on all sides with the given "pad" value.

torchvision/transforms/functional_tensor.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,44 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
194194
return _blend(img, rgb_to_grayscale(img), saturation_factor)
195195

196196

197+
def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
    r"""Perform gamma correction on an image tensor.

    Also known as Power Law Transform. Intensities are adjusted based on the
    following equation, where :math:`I_{\max}` is 1 for floating point
    tensors and 255 for integer tensors:

    .. math::
        I_{\text{out}} = I_{\max} \times \text{gain} \times \left(\frac{I_{\text{in}}}{I_{\max}}\right)^{\gamma}

    The result is clamped to :math:`[0, I_{\max}]` for both kinds of input.

    See `Gamma Correction`_ for more details.

    .. _Gamma Correction: https://en.wikipedia.org/wiki/Gamma_correction

    Args:
        img (Tensor): Image tensor to be adjusted.
        gamma (float): Non negative real number, same as :math:`\gamma` in the equation.
            gamma larger than 1 make the shadows darker,
            while gamma smaller than 1 make dark regions lighter.
        gain (float): The constant multiplier.

    Returns:
        Tensor: Gamma-corrected image with the same dtype as ``img``.

    Raises:
        TypeError: If ``img`` is not a Tensor.
        ValueError: If ``gamma`` is negative.
    """
    if not isinstance(img, torch.Tensor):
        raise TypeError('img should be a Tensor. Got {}'.format(type(img)))

    if gamma < 0:
        raise ValueError('Gamma should be a non-negative real number')

    if torch.is_floating_point(img):
        # Float images are expected in [0, 1]; clamp so that gain > 1 cannot
        # push values out of range, matching the clamp on the integer path.
        return (gain * img ** gamma).clamp(0, 1)

    result = 255.0 * gain * (img / 255.0) ** gamma
    # PIL clamps, to(torch.uint8) would wrap
    return result.clamp(0, 255).to(img.dtype)
233+
234+
197235
def center_crop(img: Tensor, output_size: BroadcastingList2[int]) -> Tensor:
198236
"""Crop the Image Tensor and resize it to desired size.
199237

0 commit comments

Comments
 (0)