Skip to content

Commit b818644

Browse files
authored
Merge branch 'main' into replace
2 parents be6764d + fbc8ea4 commit b818644

File tree

9 files changed

+39
-59
lines changed

9 files changed

+39
-59
lines changed

.circleci/unittest/linux/scripts/install.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ else
2121
fi
2222
echo "Using CUDA $CUDA_VERSION as determined by CU_VERSION"
2323
version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")"
24-
cudatoolkit="cudatoolkit=${version}"
24+
cudatoolkit="nvidia::cudatoolkit=${version}"
2525
fi
2626

2727
case "$(uname -s)" in
@@ -33,7 +33,7 @@ printf "Installing PyTorch with %s\n" "${cudatoolkit}"
3333
if [ "${os}" == "MacOSX" ]; then
3434
conda install -y -c "pytorch-${UPLOAD_CHANNEL}" "pytorch-${UPLOAD_CHANNEL}"::pytorch "${cudatoolkit}" pytest
3535
else
36-
conda install -y -c "pytorch-${UPLOAD_CHANNEL}" "pytorch-${UPLOAD_CHANNEL}"::pytorch[build="*${version}*"] "${cudatoolkit}" pytest
36+
conda install -y -c "pytorch-${UPLOAD_CHANNEL}" -c nvidia "pytorch-${UPLOAD_CHANNEL}"::pytorch[build="*${version}*"] "${cudatoolkit}" pytest
3737
fi
3838

3939
printf "* Installing torchvision\n"

packaging/build_cmake.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ else
4242
PYTORCH_MUTEX_CONSTRAINT=''
4343
fi
4444

45-
conda install -yq \pytorch=$PYTORCH_VERSION $CONDA_CUDATOOLKIT_CONSTRAINT $PYTORCH_MUTEX_CONSTRAINT $MKL_CONSTRAINT numpy -c "pytorch-${UPLOAD_CHANNEL}"
45+
conda install -yq \pytorch=$PYTORCH_VERSION $CONDA_CUDATOOLKIT_CONSTRAINT $PYTORCH_MUTEX_CONSTRAINT $MKL_CONSTRAINT numpy -c nvidia -c "pytorch-${UPLOAD_CHANNEL}"
4646
TORCH_PATH=$(dirname $(python -c "import torch; print(torch.__file__)"))
4747

4848
if [[ "$(uname)" == Darwin || "$OSTYPE" == "msys" ]]; then

packaging/build_conda.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,4 @@ if [[ "$CU_VERSION" == cu115 ]]; then
1818
export CUDATOOLKIT_CHANNEL="conda-forge"
1919
fi
2020

21-
conda build -c defaults -c $CUDATOOLKIT_CHANNEL $CONDA_CHANNEL_FLAGS --no-anaconda-upload --python "$PYTHON_VERSION" packaging/torchvision
21+
conda build -c $CUDATOOLKIT_CHANNEL -c defaults $CONDA_CHANNEL_FLAGS --no-anaconda-upload --python "$PYTHON_VERSION" packaging/torchvision

test/test_image.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def test_decode_png(img_path, pil_mode, mode):
158158

159159
img_pil = normalize_dimensions(img_pil)
160160

161-
if "16" in img_path:
161+
if img_path.endswith("16.png"):
162162
# 16 bits image decoding is supported, but only as a private API
163163
# FIXME: see https://github.com/pytorch/vision/issues/4731 for potential solutions to making it public
164164
with pytest.raises(RuntimeError, match="At most 8-bit PNG images are supported"):

test/test_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,7 @@ def test_mobilenet_norm_layer(model_fn):
406406
assert any(isinstance(x, nn.BatchNorm2d) for x in model.modules())
407407

408408
def get_gn(num_channels):
409-
return nn.GroupNorm(32, num_channels)
409+
return nn.GroupNorm(1, num_channels)
410410

411411
model = model_fn(norm_layer=get_gn)
412412
assert not (any(isinstance(x, nn.BatchNorm2d) for x in model.modules()))

torchvision/prototype/transforms/_augment.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,19 @@
77
from torchvision.prototype import features
88
from torchvision.prototype.transforms import Transform, functional as F
99

10+
from ._transform import _RandomApplyTransform
1011
from ._utils import query_image, get_image_dimensions, has_all, has_any, is_simple_tensor
1112

1213

13-
class RandomErasing(Transform):
14+
class RandomErasing(_RandomApplyTransform):
1415
def __init__(
1516
self,
1617
p: float = 0.5,
1718
scale: Tuple[float, float] = (0.02, 0.33),
1819
ratio: Tuple[float, float] = (0.3, 3.3),
1920
value: float = 0,
2021
):
21-
super().__init__()
22+
super().__init__(p=p)
2223
if not isinstance(value, (numbers.Number, str, tuple, list)):
2324
raise TypeError("Argument value should be either a number or str or a sequence")
2425
if isinstance(value, str) and value != "random":
@@ -31,9 +32,6 @@ def __init__(
3132
warnings.warn("Scale and ratio should be of kind (min, max)")
3233
if scale[0] < 0 or scale[1] > 1:
3334
raise ValueError("Scale should be between 0 and 1")
34-
if p < 0 or p > 1:
35-
raise ValueError("Random erasing probability should be between 0 and 1")
36-
self.p = p
3735
self.scale = scale
3836
self.ratio = ratio
3937
self.value = value
@@ -99,8 +97,6 @@ def forward(self, *inputs: Any) -> Any:
9997
sample = inputs if len(inputs) > 1 else inputs[0]
10098
if has_any(sample, features.BoundingBox, features.SegmentationMask):
10199
raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
102-
elif torch.rand(1) >= self.p:
103-
return sample
104100

105101
return super().forward(sample)
106102

torchvision/prototype/transforms/_container.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
from typing import Any, Optional, List
1+
from typing import Any, Optional, List, Dict
22

33
import torch
4+
from torchvision.prototype.transforms import Transform
45

5-
from ._transform import Transform
6+
from ._transform import _RandomApplyTransform
67

78

89
class Compose(Transform):
@@ -19,18 +20,13 @@ def forward(self, *inputs: Any) -> Any:
1920
return sample
2021

2122

22-
class RandomApply(Transform):
23+
class RandomApply(_RandomApplyTransform):
2324
def __init__(self, transform: Transform, *, p: float = 0.5) -> None:
24-
super().__init__()
25+
super().__init__(p=p)
2526
self.transform = transform
26-
self.p = p
27-
28-
def forward(self, *inputs: Any) -> Any:
29-
sample = inputs if len(inputs) > 1 else inputs[0]
30-
if float(torch.rand(())) < self.p:
31-
return sample
3227

33-
return self.transform(sample)
28+
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
29+
return self.transform(input)
3430

3531
def extra_repr(self) -> str:
3632
return f"p={self.p}"

torchvision/prototype/transforms/_geometry.py

Lines changed: 5 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,11 @@
1212
from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int
1313
from typing_extensions import Literal
1414

15+
from ._transform import _RandomApplyTransform
1516
from ._utils import query_image, get_image_dimensions, has_any, is_simple_tensor
1617

1718

18-
class RandomHorizontalFlip(Transform):
19-
def __init__(self, p: float = 0.5) -> None:
20-
super().__init__()
21-
self.p = p
22-
23-
def forward(self, *inputs: Any) -> Any:
24-
sample = inputs if len(inputs) > 1 else inputs[0]
25-
if torch.rand(1) >= self.p:
26-
return sample
27-
28-
return super().forward(sample)
29-
19+
class RandomHorizontalFlip(_RandomApplyTransform):
3020
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
3121
if isinstance(input, features.Image):
3222
output = F.horizontal_flip_image_tensor(input)
@@ -45,18 +35,7 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
4535
return input
4636

4737

48-
class RandomVerticalFlip(Transform):
49-
def __init__(self, p: float = 0.5) -> None:
50-
super().__init__()
51-
self.p = p
52-
53-
def forward(self, *inputs: Any) -> Any:
54-
sample = inputs if len(inputs) > 1 else inputs[0]
55-
if torch.rand(1) > self.p:
56-
return sample
57-
58-
return super().forward(sample)
59-
38+
class RandomVerticalFlip(_RandomApplyTransform):
6039
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
6140
if isinstance(input, features.Image):
6241
output = F.vertical_flip_image_tensor(input)
@@ -371,11 +350,11 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
371350
return input
372351

373352

374-
class RandomZoomOut(Transform):
353+
class RandomZoomOut(_RandomApplyTransform):
375354
def __init__(
376355
self, fill: Union[float, Sequence[float]] = 0.0, side_range: Tuple[float, float] = (1.0, 4.0), p: float = 0.5
377356
) -> None:
378-
super().__init__()
357+
super().__init__(p=p)
379358

380359
if fill is None:
381360
fill = 0.0
@@ -385,8 +364,6 @@ def __init__(
385364
if side_range[0] < 1.0 or side_range[0] > side_range[1]:
386365
raise ValueError(f"Invalid canvas side range provided {side_range}.")
387366

388-
self.p = p
389-
390367
def _get_params(self, sample: Any) -> Dict[str, Any]:
391368
image = query_image(sample)
392369
orig_c, orig_h, orig_w = get_image_dimensions(image)
@@ -411,10 +388,3 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
411388
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
412389
transform = Pad(**params, padding_mode="constant")
413390
return transform(input)
414-
415-
def forward(self, *inputs: Any) -> Any:
416-
sample = inputs if len(inputs) > 1 else inputs[0]
417-
if torch.rand(1) >= self.p:
418-
return sample
419-
420-
return super().forward(sample)

torchvision/prototype/transforms/_transform.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import functools
33
from typing import Any, Dict
44

5+
import torch
56
from torch import nn
67
from torchvision.prototype.utils._internal import apply_recursively
78
from torchvision.utils import _log_api_usage_once
@@ -34,3 +35,20 @@ def extra_repr(self) -> str:
3435
extra.append(f"{name}={value}")
3536

3637
return ", ".join(extra)
38+
39+
40+
class _RandomApplyTransform(Transform):
41+
def __init__(self, *, p: float = 0.5) -> None:
42+
if not (0.0 <= p <= 1.0):
43+
raise ValueError("`p` should be a floating point value in the interval [0.0, 1.0].")
44+
45+
super().__init__()
46+
self.p = p
47+
48+
def forward(self, *inputs: Any) -> Any:
49+
sample = inputs if len(inputs) > 1 else inputs[0]
50+
51+
if torch.rand(1) >= self.p:
52+
return sample
53+
54+
return super().forward(sample)

0 commit comments

Comments
 (0)