Skip to content

Commit 719eb7c

Browse files
datumboxfacebook-github-bot
authored andcommitted
[fbsync] [prototype] Minor improvements on functional (#6832)
Summary: * Minor improvements on functional. * Restore `_split_alpha`. * Revert "Restore `_split_alpha`." This reverts commit 2286120. Reviewed By: YosuaMichael Differential Revision: D40722902 fbshipit-source-id: 3a574939365abd1b74ed3a558b4354b1c40fc883
1 parent 64cbc3a commit 719eb7c

File tree

3 files changed

+8
-22
lines changed

3 files changed

+8
-22
lines changed

torchvision/prototype/transforms/functional/_color.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def _hsv_to_rgb(img: torch.Tensor) -> torch.Tensor:
188188
h, s, v = img.unbind(dim=-3)
189189
h6 = h * 6
190190
i = torch.floor(h6)
191-
f = (h6) - i
191+
f = h6 - i
192192
i = i.to(dtype=torch.int32)
193193

194194
p = (v * (1.0 - s)).clamp_(0.0, 1.0)
@@ -210,9 +210,6 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Ten
210210
if not (-0.5 <= hue_factor <= 0.5):
211211
raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].")
212212

213-
if not (isinstance(image, torch.Tensor)):
214-
raise TypeError("Input img should be Tensor image")
215-
216213
c = get_num_channels_image_tensor(image)
217214

218215
if c not in [1, 3]:
@@ -258,9 +255,6 @@ def adjust_hue(inpt: features.InputTypeJIT, hue_factor: float) -> features.Input
258255

259256

260257
def adjust_gamma_image_tensor(image: torch.Tensor, gamma: float, gain: float = 1.0) -> torch.Tensor:
261-
if not (isinstance(image, torch.Tensor)):
262-
raise TypeError("Input img should be Tensor image")
263-
264258
if gamma < 0:
265259
raise ValueError("Gamma should be a non-negative real number")
266260

@@ -337,10 +331,6 @@ def solarize(inpt: features.InputTypeJIT, threshold: float) -> features.InputTyp
337331

338332

339333
def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor:
340-
341-
if not (isinstance(image, torch.Tensor)):
342-
raise TypeError("Input img should be Tensor image")
343-
344334
c = get_num_channels_image_tensor(image)
345335

346336
if c not in [1, 3]:

torchvision/prototype/transforms/functional/_meta.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -183,12 +183,8 @@ def clamp_bounding_box(
183183
return convert_format_bounding_box(xyxy_boxes, BoundingBoxFormat.XYXY, format)
184184

185185

186-
def _split_alpha(image: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
187-
return image[..., :-1, :, :], image[..., -1:, :, :]
188-
189-
190186
def _strip_alpha(image: torch.Tensor) -> torch.Tensor:
191-
image, alpha = _split_alpha(image)
187+
image, alpha = torch.tensor_split(image, indices=(-1,), dim=-3)
192188
if not torch.all(alpha == _FT._max_value(alpha.dtype)):
193189
raise RuntimeError(
194190
"Stripping the alpha channel if it contains values other than the max value is not supported."
@@ -237,7 +233,7 @@ def convert_color_space_image_tensor(
237233
elif old_color_space == ColorSpace.GRAY_ALPHA and new_color_space == ColorSpace.RGB:
238234
return _gray_to_rgb(_strip_alpha(image))
239235
elif old_color_space == ColorSpace.GRAY_ALPHA and new_color_space == ColorSpace.RGB_ALPHA:
240-
image, alpha = _split_alpha(image)
236+
image, alpha = torch.tensor_split(image, indices=(-1,), dim=-3)
241237
return _add_alpha(_gray_to_rgb(image), alpha)
242238
elif old_color_space == ColorSpace.RGB and new_color_space == ColorSpace.GRAY:
243239
return _rgb_to_gray(image)
@@ -248,7 +244,7 @@ def convert_color_space_image_tensor(
248244
elif old_color_space == ColorSpace.RGB_ALPHA and new_color_space == ColorSpace.GRAY:
249245
return _rgb_to_gray(_strip_alpha(image))
250246
elif old_color_space == ColorSpace.RGB_ALPHA and new_color_space == ColorSpace.GRAY_ALPHA:
251-
image, alpha = _split_alpha(image)
247+
image, alpha = torch.tensor_split(image, indices=(-1,), dim=-3)
252248
return _add_alpha(_rgb_to_gray(image), alpha)
253249
elif old_color_space == ColorSpace.RGB_ALPHA and new_color_space == ColorSpace.RGB:
254250
return _strip_alpha(image)

torchvision/prototype/transforms/functional/_misc.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,18 +67,18 @@ def normalize(
6767
return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace)
6868

6969

70-
def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> torch.Tensor:
70+
def _get_gaussian_kernel1d(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
7171
lim = (kernel_size - 1) / (2 * math.sqrt(2) * sigma)
72-
x = torch.linspace(-lim, lim, steps=kernel_size)
72+
x = torch.linspace(-lim, lim, steps=kernel_size, dtype=dtype, device=device)
7373
kernel1d = torch.softmax(-x.pow_(2), dim=0)
7474
return kernel1d
7575

7676

7777
def _get_gaussian_kernel2d(
7878
kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device
7979
) -> torch.Tensor:
80-
kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0]).to(device, dtype=dtype)
81-
kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1]).to(device, dtype=dtype)
80+
kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0], dtype, device)
81+
kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1], dtype, device)
8282
kernel2d = kernel1d_y.unsqueeze(-1) * kernel1d_x
8383
return kernel2d
8484

0 commit comments

Comments
 (0)