Skip to content

Commit 0ff72da

Browse files
datumboxfacebook-github-bot
authored andcommitted
[fbsync] [prototype] Minor speed and nit optimizations on Transform Classes (#6837)
Summary: * Change random generator for ColorJitter. * Move `_convert_fill_arg` from runtime to constructor. * Remove unnecessary TypeVars. * Remove unnecessary casts * Update comments. * Minor code-quality changes on Geometical Transforms. * Fixing linter and other minor fixes. * Change mitigation for mypy.` * Fixing the tests. * Fixing the tests. * Fix linter * Restore dict copy. * Handling of defaultdicts * restore int idiom * Update todo Reviewed By: YosuaMichael Differential Revision: D40755989 fbshipit-source-id: d5b475ea9a603c7a137e85db08dcd0db30195e3c
1 parent 1f72391 commit 0ff72da

File tree

8 files changed

+80
-89
lines changed

8 files changed

+80
-89
lines changed

test/test_prototype_transforms.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ def test__transform(self, padding, fill, padding_mode, mocker):
389389
inpt = mocker.MagicMock(spec=features.Image)
390390
_ = transform(inpt)
391391

392-
fill = transforms.functional._geometry._convert_fill_arg(fill)
392+
fill = transforms._utils._convert_fill_arg(fill)
393393
if isinstance(padding, tuple):
394394
padding = list(padding)
395395
fn.assert_called_once_with(inpt, padding=padding, fill=fill, padding_mode=padding_mode)
@@ -405,14 +405,14 @@ def test__transform_image_mask(self, fill, mocker):
405405
_ = transform(inpt)
406406

407407
if isinstance(fill, int):
408-
fill = transforms.functional._geometry._convert_fill_arg(fill)
408+
fill = transforms._utils._convert_fill_arg(fill)
409409
calls = [
410410
mocker.call(image, padding=1, fill=fill, padding_mode="constant"),
411411
mocker.call(mask, padding=1, fill=fill, padding_mode="constant"),
412412
]
413413
else:
414-
fill_img = transforms.functional._geometry._convert_fill_arg(fill[type(image)])
415-
fill_mask = transforms.functional._geometry._convert_fill_arg(fill[type(mask)])
414+
fill_img = transforms._utils._convert_fill_arg(fill[type(image)])
415+
fill_mask = transforms._utils._convert_fill_arg(fill[type(mask)])
416416
calls = [
417417
mocker.call(image, padding=1, fill=fill_img, padding_mode="constant"),
418418
mocker.call(mask, padding=1, fill=fill_mask, padding_mode="constant"),
@@ -466,7 +466,7 @@ def test__transform(self, fill, side_range, mocker):
466466
torch.rand(1) # random apply changes random state
467467
params = transform._get_params([inpt])
468468

469-
fill = transforms.functional._geometry._convert_fill_arg(fill)
469+
fill = transforms._utils._convert_fill_arg(fill)
470470
fn.assert_called_once_with(inpt, **params, fill=fill)
471471

472472
@pytest.mark.parametrize("fill", [12, {features.Image: 12, features.Mask: 34}])
@@ -485,14 +485,14 @@ def test__transform_image_mask(self, fill, mocker):
485485
params = transform._get_params(inpt)
486486

487487
if isinstance(fill, int):
488-
fill = transforms.functional._geometry._convert_fill_arg(fill)
488+
fill = transforms._utils._convert_fill_arg(fill)
489489
calls = [
490490
mocker.call(image, **params, fill=fill),
491491
mocker.call(mask, **params, fill=fill),
492492
]
493493
else:
494-
fill_img = transforms.functional._geometry._convert_fill_arg(fill[type(image)])
495-
fill_mask = transforms.functional._geometry._convert_fill_arg(fill[type(mask)])
494+
fill_img = transforms._utils._convert_fill_arg(fill[type(image)])
495+
fill_mask = transforms._utils._convert_fill_arg(fill[type(mask)])
496496
calls = [
497497
mocker.call(image, **params, fill=fill_img),
498498
mocker.call(mask, **params, fill=fill_mask),
@@ -556,7 +556,7 @@ def test__transform(self, degrees, expand, fill, center, mocker):
556556
torch.manual_seed(12)
557557
params = transform._get_params(inpt)
558558

559-
fill = transforms.functional._geometry._convert_fill_arg(fill)
559+
fill = transforms._utils._convert_fill_arg(fill)
560560
fn.assert_called_once_with(inpt, **params, interpolation=interpolation, expand=expand, fill=fill, center=center)
561561

562562
@pytest.mark.parametrize("angle", [34, -87])
@@ -694,7 +694,7 @@ def test__transform(self, degrees, translate, scale, shear, fill, center, mocker
694694
torch.manual_seed(12)
695695
params = transform._get_params([inpt])
696696

697-
fill = transforms.functional._geometry._convert_fill_arg(fill)
697+
fill = transforms._utils._convert_fill_arg(fill)
698698
fn.assert_called_once_with(inpt, **params, interpolation=interpolation, fill=fill, center=center)
699699

700700

@@ -939,7 +939,7 @@ def test__transform(self, distortion_scale, mocker):
939939
torch.rand(1) # random apply changes random state
940940
params = transform._get_params([inpt])
941941

942-
fill = transforms.functional._geometry._convert_fill_arg(fill)
942+
fill = transforms._utils._convert_fill_arg(fill)
943943
fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation)
944944

945945

@@ -1009,7 +1009,7 @@ def test__transform(self, alpha, sigma, mocker):
10091009
transform._get_params = mocker.MagicMock()
10101010
_ = transform(inpt)
10111011
params = transform._get_params([inpt])
1012-
fill = transforms.functional._geometry._convert_fill_arg(fill)
1012+
fill = transforms._utils._convert_fill_arg(fill)
10131013
fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation)
10141014

10151015

@@ -1632,7 +1632,7 @@ def test__transform(self, mocker, needs):
16321632
if not needs_crop:
16331633
assert args[0] is inpt_sentinel
16341634
assert args[1] is padding_sentinel
1635-
fill_sentinel = transforms.functional._geometry._convert_fill_arg(fill_sentinel)
1635+
fill_sentinel = transforms._utils._convert_fill_arg(fill_sentinel)
16361636
assert kwargs == dict(fill=fill_sentinel, padding_mode=padding_mode_sentinel)
16371637
else:
16381638
mock_pad.assert_not_called()

test/test_prototype_transforms_consistency.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -983,8 +983,6 @@ def _transform(self, inpt, params):
983983
return inpt
984984

985985
fill = self.fill[type(inpt)]
986-
fill = F._geometry._convert_fill_arg(fill)
987-
988986
return F.pad(inpt, padding=params["padding"], fill=fill)
989987

990988

torchvision/prototype/transforms/_auto_augment.py

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import math
2-
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, TypeVar, Union
2+
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
33

44
import PIL.Image
55
import torch
@@ -11,9 +11,6 @@
1111

1212
from ._utils import _isinstance, _setup_fill_arg
1313

14-
K = TypeVar("K")
15-
V = TypeVar("V")
16-
1714

1815
class _AutoAugmentBase(Transform):
1916
def __init__(
@@ -26,7 +23,7 @@ def __init__(
2623
self.interpolation = interpolation
2724
self.fill = _setup_fill_arg(fill)
2825

29-
def _get_random_item(self, dct: Dict[K, V]) -> Tuple[K, V]:
26+
def _get_random_item(self, dct: Dict[str, Tuple[Callable, bool]]) -> Tuple[str, Tuple[Callable, bool]]:
3027
keys = tuple(dct.keys())
3128
key = keys[int(torch.randint(len(keys), ()))]
3229
return key, dct[key]
@@ -71,10 +68,9 @@ def _apply_image_or_video_transform(
7168
transform_id: str,
7269
magnitude: float,
7370
interpolation: InterpolationMode,
74-
fill: Dict[Type, features.FillType],
71+
fill: Dict[Type, features.FillTypeJIT],
7572
) -> Union[features.ImageType, features.VideoType]:
7673
fill_ = fill[type(image)]
77-
fill_ = F._geometry._convert_fill_arg(fill_)
7874

7975
if transform_id == "Identity":
8076
return image
@@ -170,9 +166,7 @@ class AutoAugment(_AutoAugmentBase):
170166
"Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
171167
"Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
172168
"Posterize": (
173-
lambda num_bins, height, width: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)))
174-
.round()
175-
.int(),
169+
lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
176170
False,
177171
),
178172
"Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False),
@@ -327,9 +321,7 @@ class RandAugment(_AutoAugmentBase):
327321
"Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
328322
"Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
329323
"Posterize": (
330-
lambda num_bins, height, width: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)))
331-
.round()
332-
.int(),
324+
lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
333325
False,
334326
),
335327
"Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False),
@@ -383,9 +375,7 @@ class TrivialAugmentWide(_AutoAugmentBase):
383375
"Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
384376
"Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
385377
"Posterize": (
386-
lambda num_bins, height, width: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 6)))
387-
.round()
388-
.int(),
378+
lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6))).round().int(),
389379
False,
390380
),
391381
"Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False),
@@ -430,9 +420,7 @@ class AugMix(_AutoAugmentBase):
430420
"TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, height / 3.0, num_bins), True),
431421
"Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
432422
"Posterize": (
433-
lambda num_bins, height, width: cast(torch.Tensor, 4 - (torch.arange(num_bins) / ((num_bins - 1) / 4)))
434-
.round()
435-
.int(),
423+
lambda num_bins, height, width: (4 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
436424
False,
437425
),
438426
"Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False),
@@ -517,7 +505,13 @@ def forward(self, *inputs: Any) -> Any:
517505
aug = self._apply_image_or_video_transform(
518506
aug, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
519507
)
520-
mix.add_(combined_weights[:, i].reshape(batch_dims) * aug)
508+
mix.add_(
509+
# The multiplication below could become in-place provided `aug is not batch and aug.is_floating_point()`
510+
# Currently we can't do this because `aug` has to be `unint8` to support ops like `equalize`.
511+
# TODO: change this once all ops in `F` support floats. https://github.com/pytorch/vision/issues/6840
512+
combined_weights[:, i].reshape(batch_dims)
513+
* aug
514+
)
521515
mix = mix.reshape(orig_dims).to(dtype=image_or_video.dtype)
522516

523517
if isinstance(orig_image_or_video, (features.Image, features.Video)):

torchvision/prototype/transforms/_color.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def _check_input(
5151

5252
@staticmethod
5353
def _generate_value(left: float, right: float) -> float:
54-
return float(torch.distributions.Uniform(left, right).sample())
54+
return torch.empty(1).uniform_(left, right).item()
5555

5656
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
5757
fn_idx = torch.randperm(4)

torchvision/prototype/transforms/_geometry.py

Lines changed: 19 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -223,20 +223,16 @@ def __init__(
223223
_check_padding_arg(padding)
224224
_check_padding_mode_arg(padding_mode)
225225

226+
# This cast does Sequence[int] -> List[int] and is required to make mypy happy
227+
if not isinstance(padding, int):
228+
padding = list(padding)
226229
self.padding = padding
227230
self.fill = _setup_fill_arg(fill)
228231
self.padding_mode = padding_mode
229232

230233
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
231234
fill = self.fill[type(inpt)]
232-
233-
# This cast does Sequence[int] -> List[int] and is required to make mypy happy
234-
padding = self.padding
235-
if not isinstance(padding, int):
236-
padding = list(padding)
237-
238-
fill = F._geometry._convert_fill_arg(fill)
239-
return F.pad(inpt, padding=padding, fill=fill, padding_mode=self.padding_mode)
235+
return F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode)
240236

241237

242238
class RandomZoomOut(_RandomApplyTransform):
@@ -274,7 +270,6 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
274270

275271
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
276272
fill = self.fill[type(inpt)]
277-
fill = F._geometry._convert_fill_arg(fill)
278273
return F.pad(inpt, **params, fill=fill)
279274

280275

@@ -300,12 +295,11 @@ def __init__(
300295
self.center = center
301296

302297
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
303-
angle = float(torch.empty(1).uniform_(float(self.degrees[0]), float(self.degrees[1])).item())
298+
angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item()
304299
return dict(angle=angle)
305300

306301
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
307302
fill = self.fill[type(inpt)]
308-
fill = F._geometry._convert_fill_arg(fill)
309303
return F.rotate(
310304
inpt,
311305
**params,
@@ -358,7 +352,7 @@ def __init__(
358352
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
359353
height, width = query_spatial_size(flat_inputs)
360354

361-
angle = float(torch.empty(1).uniform_(float(self.degrees[0]), float(self.degrees[1])).item())
355+
angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item()
362356
if self.translate is not None:
363357
max_dx = float(self.translate[0] * width)
364358
max_dy = float(self.translate[1] * height)
@@ -369,22 +363,21 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
369363
translate = (0, 0)
370364

371365
if self.scale is not None:
372-
scale = float(torch.empty(1).uniform_(self.scale[0], self.scale[1]).item())
366+
scale = torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()
373367
else:
374368
scale = 1.0
375369

376370
shear_x = shear_y = 0.0
377371
if self.shear is not None:
378-
shear_x = float(torch.empty(1).uniform_(self.shear[0], self.shear[1]).item())
372+
shear_x = torch.empty(1).uniform_(self.shear[0], self.shear[1]).item()
379373
if len(self.shear) == 4:
380-
shear_y = float(torch.empty(1).uniform_(self.shear[2], self.shear[3]).item())
374+
shear_y = torch.empty(1).uniform_(self.shear[2], self.shear[3]).item()
381375

382376
shear = (shear_x, shear_y)
383377
return dict(angle=angle, translate=translate, scale=scale, shear=shear)
384378

385379
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
386380
fill = self.fill[type(inpt)]
387-
fill = F._geometry._convert_fill_arg(fill)
388381
return F.affine(
389382
inpt,
390383
**params,
@@ -478,8 +471,6 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
478471
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
479472
if params["needs_pad"]:
480473
fill = self.fill[type(inpt)]
481-
fill = F._geometry._convert_fill_arg(fill)
482-
483474
inpt = F.pad(inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode)
484475

485476
if params["needs_crop"]:
@@ -512,21 +503,23 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
512503

513504
half_height = height // 2
514505
half_width = width // 2
506+
bound_height = int(distortion_scale * half_height) + 1
507+
bound_width = int(distortion_scale * half_width) + 1
515508
topleft = [
516-
int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1,)).item()),
517-
int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1,)).item()),
509+
int(torch.randint(0, bound_width, size=(1,))),
510+
int(torch.randint(0, bound_height, size=(1,))),
518511
]
519512
topright = [
520-
int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1,)).item()),
521-
int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1,)).item()),
513+
int(torch.randint(width - bound_width, width, size=(1,))),
514+
int(torch.randint(0, bound_height, size=(1,))),
522515
]
523516
botright = [
524-
int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1,)).item()),
525-
int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1,)).item()),
517+
int(torch.randint(width - bound_width, width, size=(1,))),
518+
int(torch.randint(height - bound_height, height, size=(1,))),
526519
]
527520
botleft = [
528-
int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1,)).item()),
529-
int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1,)).item()),
521+
int(torch.randint(0, bound_width, size=(1,))),
522+
int(torch.randint(height - bound_height, height, size=(1,))),
530523
]
531524
startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]
532525
endpoints = [topleft, topright, botright, botleft]
@@ -535,7 +528,6 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
535528

536529
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
537530
fill = self.fill[type(inpt)]
538-
fill = F._geometry._convert_fill_arg(fill)
539531
return F.perspective(
540532
inpt,
541533
**params,
@@ -584,7 +576,6 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
584576

585577
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
586578
fill = self.fill[type(inpt)]
587-
fill = F._geometry._convert_fill_arg(fill)
588579
return F.elastic(
589580
inpt,
590581
**params,
@@ -855,7 +846,6 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
855846

856847
if params["needs_pad"]:
857848
fill = self.fill[type(inpt)]
858-
fill = F._geometry._convert_fill_arg(fill)
859849
inpt = F.pad(inpt, params["padding"], fill=fill, padding_mode=self.padding_mode)
860850

861851
return inpt

torchvision/prototype/transforms/_type_conversion.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, cast, Dict, Optional, Union
1+
from typing import Any, Dict, Optional, Union
22

33
import numpy as np
44
import PIL.Image
@@ -13,7 +13,7 @@ class DecodeImage(Transform):
1313
_transformed_types = (features.EncodedImage,)
1414

1515
def _transform(self, inpt: torch.Tensor, params: Dict[str, Any]) -> features.Image:
16-
return cast(features.Image, F.decode_image_with_pil(inpt))
16+
return F.decode_image_with_pil(inpt) # type: ignore[no-any-return]
1717

1818

1919
class LabelToOneHot(Transform):
@@ -27,7 +27,7 @@ def _transform(self, inpt: features.Label, params: Dict[str, Any]) -> features.O
2727
num_categories = self.num_categories
2828
if num_categories == -1 and inpt.categories is not None:
2929
num_categories = len(inpt.categories)
30-
output = one_hot(inpt, num_classes=num_categories)
30+
output = one_hot(inpt.as_subclass(torch.Tensor), num_classes=num_categories)
3131
return features.OneHotLabel(output, categories=inpt.categories)
3232

3333
def extra_repr(self) -> str:
@@ -50,7 +50,7 @@ class ToImageTensor(Transform):
5050
def _transform(
5151
self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any]
5252
) -> features.Image:
53-
return cast(features.Image, F.to_image_tensor(inpt))
53+
return F.to_image_tensor(inpt) # type: ignore[no-any-return]
5454

5555

5656
class ToImagePIL(Transform):

0 commit comments

Comments
 (0)