Skip to content

Commit ebea484

Browse files
authored
Merge branch 'main' into patch-1
2 parents acb2bbd + b5c7443 commit ebea484

File tree

12 files changed

+166
-40
lines changed

12 files changed

+166
-40
lines changed

.github/scripts/setup-env.sh

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,7 @@ echo '::endgroup::'
4545

4646
if [[ "${OS_TYPE}" == windows && "${GPU_ARCH_TYPE}" == cuda ]]; then
4747
echo '::group::Install VisualStudio CUDA extensions on Windows'
48-
if [[ "${VC_YEAR:-}" == "2022" ]]; then
49-
TARGET_DIR="/c/Program Files (x86)/Microsoft Visual Studio/2022/BuildTools/MSBuild/Microsoft/VC/v170/BuildCustomizations"
50-
else
51-
TARGET_DIR="/c/Program Files (x86)/Microsoft Visual Studio/2019/BuildTools/MSBuild/Microsoft/VC/v160/BuildCustomizations"
52-
fi
48+
TARGET_DIR="/c/Program Files (x86)/Microsoft Visual Studio/2022/BuildTools/MSBuild/Microsoft/VC/v170/BuildCustomizations"
5349
mkdir -p "${TARGET_DIR}"
5450
cp -r "${CUDA_HOME}/MSBuildExtensions/"* "${TARGET_DIR}"
5551
echo '::endgroup::'

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ jobs:
102102
set -euxo pipefail
103103
104104
export PYTHON_VERSION=${{ matrix.python-version }}
105-
export VC_YEAR=2019
105+
export VC_YEAR=2022
106106
export VSDEVCMD_ARGS=""
107107
export GPU_ARCH_TYPE=${{ matrix.gpu-arch-type }}
108108
export GPU_ARCH_VERSION=${{ matrix.gpu-arch-version }}

packaging/windows/internal/vc_env_helper.bat

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,6 @@
22

33
set VC_VERSION_LOWER=17
44
set VC_VERSION_UPPER=18
5-
if "%VC_YEAR%" == "2019" (
6-
set VC_VERSION_LOWER=16
7-
set VC_VERSION_UPPER=17
8-
)
9-
if "%VC_YEAR%" == "2017" (
10-
set VC_VERSION_LOWER=15
11-
set VC_VERSION_UPPER=16
12-
)
135

146
for /f "usebackq tokens=*" %%i in (`"%ProgramFiles(x86)%\Microsoft Visual Studio\Installer\vswhere.exe" -legacy -products * -version [%VC_VERSION_LOWER%^,%VC_VERSION_UPPER%^) -property installationPath`) do (
157
if exist "%%i" if exist "%%i\VC\Auxiliary\Build\vcvarsall.bat" (

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def get_version():
7979

8080
def write_version_file(version, sha):
8181
# Exists for BC, probably completely useless.
82-
with open(ROOT_DIR / "torchvision/version.py", "w") as f:
82+
with open(ROOT_DIR / "torchvision" / "version.py", "w") as f:
8383
f.write(f"__version__ = '{version}'\n")
8484
f.write(f"git_version = {repr(sha)}\n")
8585
f.write("from torchvision.extension import _check_cuda_version\n")
@@ -194,7 +194,7 @@ def make_C_extension():
194194

195195
def find_libpng():
196196
# Returns (found, include dir, library dir, library name)
197-
if sys.platform in ("linux", "darwin"):
197+
if sys.platform in ("linux", "darwin", "aix"):
198198
libpng_config = shutil.which("libpng-config")
199199
if libpng_config is None:
200200
warnings.warn("libpng-config not found")

test/test_datasets.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -532,7 +532,8 @@ def inject_fake_data(self, tmpdir, config):
532532
self._create_bbox_txt(base_folder, num_images)
533533
self._create_landmarks_txt(base_folder, num_images)
534534

535-
return dict(num_examples=num_images_per_split[config["split"]], attr_names=attr_names)
535+
num_samples = num_images_per_split.get(config["split"], 0) if isinstance(config["split"], str) else 0
536+
return dict(num_examples=num_samples, attr_names=attr_names)
536537

537538
def _create_split_txt(self, root):
538539
num_images_per_split = dict(train=4, valid=3, test=2)
@@ -635,6 +636,28 @@ def test_transforms_v2_wrapper_spawn(self):
635636
with self.create_dataset(target_type=target_type, transform=v2.Resize(size=expected_size)) as (dataset, _):
636637
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)
637638

639+
def test_invalid_split_list(self):
640+
with pytest.raises(ValueError, match="Expected type str for argument split, but got type <class 'list'>."):
641+
with self.create_dataset(split=[1]):
642+
pass
643+
644+
def test_invalid_split_int(self):
645+
with pytest.raises(ValueError, match="Expected type str for argument split, but got type <class 'int'>."):
646+
with self.create_dataset(split=1):
647+
pass
648+
649+
def test_invalid_split_value(self):
650+
with pytest.raises(
651+
ValueError,
652+
match="Unknown value '{value}' for argument {arg}. Valid values are {{{valid_values}}}.".format(
653+
value="invalid",
654+
arg="split",
655+
valid_values=("train", "valid", "test", "all"),
656+
),
657+
):
658+
with self.create_dataset(split="invalid"):
659+
pass
660+
638661

639662
class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase):
640663
DATASET_CLASS = datasets.VOCSegmentation

torchvision/datasets/celeba.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,13 @@ def __init__(
9393
"test": 2,
9494
"all": None,
9595
}
96-
split_ = split_map[verify_str_arg(split.lower(), "split", ("train", "valid", "test", "all"))]
96+
split_ = split_map[
97+
verify_str_arg(
98+
split.lower() if isinstance(split, str) else split,
99+
"split",
100+
("train", "valid", "test", "all"),
101+
)
102+
]
97103
splits = self._load_csv("list_eval_partition.txt")
98104
identity = self._load_csv("identity_CelebA.txt")
99105
bbox = self._load_csv("list_bbox_celeba.txt", header=1)

torchvision/datasets/flowers102.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,3 +112,108 @@ def download(self):
112112
for id in ["label", "setid"]:
113113
filename, md5 = self._file_dict[id]
114114
download_url(self._download_url_prefix + filename, str(self._base_folder), md5=md5)
115+
116+
classes = [
117+
"pink primrose",
118+
"hard-leaved pocket orchid",
119+
"canterbury bells",
120+
"sweet pea",
121+
"english marigold",
122+
"tiger lily",
123+
"moon orchid",
124+
"bird of paradise",
125+
"monkshood",
126+
"globe thistle",
127+
"snapdragon",
128+
"colt's foot",
129+
"king protea",
130+
"spear thistle",
131+
"yellow iris",
132+
"globe-flower",
133+
"purple coneflower",
134+
"peruvian lily",
135+
"balloon flower",
136+
"giant white arum lily",
137+
"fire lily",
138+
"pincushion flower",
139+
"fritillary",
140+
"red ginger",
141+
"grape hyacinth",
142+
"corn poppy",
143+
"prince of wales feathers",
144+
"stemless gentian",
145+
"artichoke",
146+
"sweet william",
147+
"carnation",
148+
"garden phlox",
149+
"love in the mist",
150+
"mexican aster",
151+
"alpine sea holly",
152+
"ruby-lipped cattleya",
153+
"cape flower",
154+
"great masterwort",
155+
"siam tulip",
156+
"lenten rose",
157+
"barbeton daisy",
158+
"daffodil",
159+
"sword lily",
160+
"poinsettia",
161+
"bolero deep blue",
162+
"wallflower",
163+
"marigold",
164+
"buttercup",
165+
"oxeye daisy",
166+
"common dandelion",
167+
"petunia",
168+
"wild pansy",
169+
"primula",
170+
"sunflower",
171+
"pelargonium",
172+
"bishop of llandaff",
173+
"gaura",
174+
"geranium",
175+
"orange dahlia",
176+
"pink-yellow dahlia?",
177+
"cautleya spicata",
178+
"japanese anemone",
179+
"black-eyed susan",
180+
"silverbush",
181+
"californian poppy",
182+
"osteospermum",
183+
"spring crocus",
184+
"bearded iris",
185+
"windflower",
186+
"tree poppy",
187+
"gazania",
188+
"azalea",
189+
"water lily",
190+
"rose",
191+
"thorn apple",
192+
"morning glory",
193+
"passion flower",
194+
"lotus",
195+
"toad lily",
196+
"anthurium",
197+
"frangipani",
198+
"clematis",
199+
"hibiscus",
200+
"columbine",
201+
"desert-rose",
202+
"tree mallow",
203+
"magnolia",
204+
"cyclamen",
205+
"watercress",
206+
"canna lily",
207+
"hippeastrum",
208+
"bee balm",
209+
"ball moss",
210+
"foxglove",
211+
"bougainvillea",
212+
"camellia",
213+
"mallow",
214+
"mexican petunia",
215+
"bromelia",
216+
"blanket flower",
217+
"trumpet creeper",
218+
"blackberry lily",
219+
]

torchvision/datasets/mnist.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,18 +25,18 @@ class MNIST(VisionDataset):
2525
and ``MNIST/raw/t10k-images-idx3-ubyte`` exist.
2626
train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
2727
otherwise from ``t10k-images-idx3-ubyte``.
28-
download (bool, optional): If True, downloads the dataset from the internet and
29-
puts it in root directory. If dataset is already downloaded, it is not
30-
downloaded again.
3128
transform (callable, optional): A function/transform that takes in a PIL image
3229
and returns a transformed version. E.g., ``transforms.RandomCrop``
3330
target_transform (callable, optional): A function/transform that takes in the
3431
target and transforms it.
32+
download (bool, optional): If True, downloads the dataset from the internet and
33+
puts it in root directory. If dataset is already downloaded, it is not
34+
downloaded again.
3535
"""
3636

3737
mirrors = [
38-
"http://yann.lecun.com/exdb/mnist/",
3938
"https://ossci-datasets.s3.amazonaws.com/mnist/",
39+
"http://yann.lecun.com/exdb/mnist/",
4040
]
4141

4242
resources = [
@@ -209,13 +209,13 @@ class FashionMNIST(MNIST):
209209
and ``FashionMNIST/raw/t10k-images-idx3-ubyte`` exist.
210210
train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
211211
otherwise from ``t10k-images-idx3-ubyte``.
212-
download (bool, optional): If True, downloads the dataset from the internet and
213-
puts it in root directory. If dataset is already downloaded, it is not
214-
downloaded again.
215212
transform (callable, optional): A function/transform that takes in a PIL image
216213
and returns a transformed version. E.g., ``transforms.RandomCrop``
217214
target_transform (callable, optional): A function/transform that takes in the
218215
target and transforms it.
216+
download (bool, optional): If True, downloads the dataset from the internet and
217+
puts it in root directory. If dataset is already downloaded, it is not
218+
downloaded again.
219219
"""
220220

221221
mirrors = ["http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/"]
@@ -237,13 +237,13 @@ class KMNIST(MNIST):
237237
and ``KMNIST/raw/t10k-images-idx3-ubyte`` exist.
238238
train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
239239
otherwise from ``t10k-images-idx3-ubyte``.
240-
download (bool, optional): If True, downloads the dataset from the internet and
241-
puts it in root directory. If dataset is already downloaded, it is not
242-
downloaded again.
243240
transform (callable, optional): A function/transform that takes in a PIL image
244241
and returns a transformed version. E.g., ``transforms.RandomCrop``
245242
target_transform (callable, optional): A function/transform that takes in the
246243
target and transforms it.
244+
download (bool, optional): If True, downloads the dataset from the internet and
245+
puts it in root directory. If dataset is already downloaded, it is not
246+
downloaded again.
247247
"""
248248

249249
mirrors = ["http://codh.rois.ac.jp/kmnist/dataset/kmnist/"]
@@ -358,6 +358,9 @@ class QMNIST(MNIST):
358358
for each example is class number (for compatibility with
359359
the MNIST dataloader) or a torch vector containing the
360360
full qmnist information. Default=True.
361+
train (bool, optional, compatibility): When argument 'what' is
362+
not specified, this boolean decides whether to load the
363+
training set or the testing set. Default: True.
361364
download (bool, optional): If True, downloads the dataset from
362365
the internet and puts it in root directory. If dataset is
363366
already downloaded, it is not downloaded again.
@@ -366,9 +369,6 @@ class QMNIST(MNIST):
366369
version. E.g., ``transforms.RandomCrop``
367370
target_transform (callable, optional): A function/transform
368371
that takes in the target and transforms it.
369-
train (bool,optional,compatibility): When argument 'what' is
370-
not specified, this boolean decides whether to load the
371-
training set or the testing set. Default: True.
372372
"""
373373

374374
subsets = {"train": "train", "test": "test", "test10k": "test", "test50k": "test", "nist": "nist"}
@@ -514,7 +514,7 @@ def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tenso
514514
data = f.read()
515515

516516
# parse
517-
if sys.byteorder == "little":
517+
if sys.byteorder == "little" or sys.platform == "aix":
518518
magic = get_int(data[0:4])
519519
nd = magic % 256
520520
ty = magic // 256
@@ -527,7 +527,7 @@ def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tenso
527527
torch_type = SN3_PASCALVINCENT_TYPEMAP[ty]
528528
s = [get_int(data[4 * (i + 1) : 4 * (i + 2)]) for i in range(nd)]
529529

530-
if sys.byteorder == "big":
530+
if sys.byteorder == "big" and not sys.platform == "aix":
531531
for i in range(len(s)):
532532
s[i] = int.from_bytes(s[i].to_bytes(4, byteorder="little"), byteorder="big", signed=False)
533533

torchvision/datasets/moving_mnist.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@ class MovingMNIST(VisionDataset):
1818
split_ratio (int, optional): The split ratio of number of frames. If ``split="train"``, the first split
1919
frames ``data[:, :split_ratio]`` is returned. If ``split="test"``, the last split frames ``data[:, split_ratio:]``
2020
is returned. If ``split=None``, this parameter is ignored and all frames are returned.
21-
transform (callable, optional): A function/transform that takes in a torch Tensor
22-
and returns a transformed version. E.g, ``transforms.RandomCrop``
2321
download (bool, optional): If true, downloads the dataset from the internet and
2422
puts it in root directory. If dataset is already downloaded, it is not
2523
downloaded again.
24+
transform (callable, optional): A function/transform that takes in a torch Tensor
25+
and returns a transformed version. E.g., ``transforms.RandomCrop``
2626
"""
2727

2828
_URL = "http://www.cs.toronto.edu/~nitish/unsupervised_video/mnist_test_seq.npy"

torchvision/datasets/oxford_iiit_pet.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ class OxfordIIITPet(VisionDataset):
2727
transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
2828
version. E.g., ``transforms.RandomCrop``.
2929
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
30+
transforms (callable, optional): A function/transform that takes input sample
31+
and its target as entry and returns a transformed version.
3032
download (bool, optional): If True, downloads the dataset from the internet and puts it into
3133
``root/oxford-iiit-pet``. If dataset is already downloaded, it is not downloaded again.
3234
"""

torchvision/models/_api.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from functools import partial
88
from inspect import signature
99
from types import ModuleType
10-
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Set, Type, TypeVar, Union
10+
from typing import Any, Callable, Dict, get_args, Iterable, List, Mapping, Optional, Set, Type, TypeVar, Union
1111

1212
from torch import nn
1313

@@ -168,14 +168,13 @@ def _get_enum_from_fn(fn: Callable) -> Type[WeightsEnum]:
168168
if "weights" not in sig.parameters:
169169
raise ValueError("The method is missing the 'weights' argument.")
170170

171-
ann = signature(fn).parameters["weights"].annotation
171+
ann = sig.parameters["weights"].annotation
172172
weights_enum = None
173173
if isinstance(ann, type) and issubclass(ann, WeightsEnum):
174174
weights_enum = ann
175175
else:
176176
# handle cases like Union[Optional, T]
177-
# TODO: Replace ann.__args__ with typing.get_args(ann) after python >= 3.8
178-
for t in ann.__args__: # type: ignore[union-attr]
177+
for t in get_args(ann):
179178
if isinstance(t, type) and issubclass(t, WeightsEnum):
180179
weights_enum = t
181180
break

torchvision/ops/focal_loss.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def sigmoid_focal_loss(
2020
targets (Tensor): A float tensor with the same shape as inputs. Stores the binary
2121
classification label for each element in inputs
2222
(0 for the negative class and 1 for the positive class).
23-
alpha (float): Weighting factor in range (0,1) to balance
23+
alpha (float): Weighting factor in range [0, 1] to balance
2424
positive vs negative examples or -1 for ignore. Default: ``0.25``.
2525
gamma (float): Exponent of the modulating factor (1 - p_t) to
2626
balance easy vs hard examples. Default: ``2``.
@@ -33,6 +33,9 @@ def sigmoid_focal_loss(
3333
"""
3434
# Original implementation from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py
3535

36+
if not (0 <= alpha <= 1) and alpha != -1:
37+
raise ValueError(f"Invalid alpha value: {alpha}. alpha must be in the range [0,1] or -1 for ignore.")
38+
3639
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
3740
_log_api_usage_once(sigmoid_focal_loss)
3841
p = torch.sigmoid(inputs)

0 commit comments

Comments
 (0)