From 06066743a6f518601a8b894a90e28db957d0b458 Mon Sep 17 00:00:00 2001 From: Lezwon Castelino Date: Sat, 19 Mar 2022 19:34:00 +0530 Subject: [PATCH 01/20] added usps dataset --- .../prototype/datasets/_builtin/__init__.py | 1 + .../prototype/datasets/_builtin/usps.py | 72 +++++++++++++++++++ 2 files changed, 73 insertions(+) create mode 100644 torchvision/prototype/datasets/_builtin/usps.py diff --git a/torchvision/prototype/datasets/_builtin/__init__.py b/torchvision/prototype/datasets/_builtin/__init__.py index feb558aa03f..1a8dc0907a4 100644 --- a/torchvision/prototype/datasets/_builtin/__init__.py +++ b/torchvision/prototype/datasets/_builtin/__init__.py @@ -17,4 +17,5 @@ from .semeion import SEMEION from .stanford_cars import StanfordCars from .svhn import SVHN +from .usps import USPS from .voc import VOC diff --git a/torchvision/prototype/datasets/_builtin/usps.py b/torchvision/prototype/datasets/_builtin/usps.py new file mode 100644 index 00000000000..92efe1e582b --- /dev/null +++ b/torchvision/prototype/datasets/_builtin/usps.py @@ -0,0 +1,72 @@ +import bz2 +import functools +from typing import Any, Dict, List, Tuple, BinaryIO, Iterator + +import numpy as np +import torch +from torchdata.datapipes.iter import IterDataPipe, IterableWrapper, LineReader, Mapper +from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, DatasetConfig, OnlineResource, HttpResource +from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling +from torchvision.prototype.features import Image, Label + + +class USPSFileReader(IterDataPipe[torch.Tensor]): + def __init__(self, datapipe: IterDataPipe[Tuple[Any, BinaryIO]]) -> None: + self.datapipe = datapipe + + def __iter__(self) -> Iterator[torch.Tensor]: + for path, _ in self.datapipe: + with bz2.open(path) as fp: + datapipe = IterableWrapper([(path, fp)]) + line_reader = LineReader(datapipe, decode=True) + for _, line in line_reader: + raw_data = line.split() + tmp_list = [x.split(":")[-1] for x in raw_data[1:]] + img = np.asarray(tmp_list, dtype=np.float32).reshape((-1, 16, 16)) + img = ((img + 1) / 2 * 255).astype(dtype=np.uint8) + target = int(raw_data[0]) - 1 + yield torch.from_numpy(img), torch.tensor(target) + + +class USPS(Dataset): + def _make_info(self) -> DatasetInfo: + return DatasetInfo( + "usps", + homepage="https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass.html#usps", + valid_options=dict( + split=("train", "test"), + ), + categories=10, + ) + + _URL = "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass" + + _RESOURCES = { + "train": HttpResource( + f"{_URL}/usps.bz2", sha256="3771e9dd6ba685185f89867b6e249233dd74652389f263963b3b741e994b034f" + ), + "test": HttpResource( + f"{_URL}/usps.t.bz2", sha256="a9c0164e797d60142a50604917f0baa604f326e9a689698763793fa5d12ffc4e" + ), + } + + def resources(self, config: DatasetConfig) -> List[OnlineResource]: + return [USPS._RESOURCES[config.split]] + + def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor], *, config: DatasetConfig) -> Dict[str, Any]: + image, label = data + return dict( + image=Image(image), + label=Label(label, dtype=torch.int64, categories=self.categories), + ) + + def _make_datapipe( + self, + resource_dps: List[IterDataPipe], + *, + config: DatasetConfig, + ) -> IterDataPipe[Dict[str, Any]]: + dp = USPSFileReader(resource_dps[0]) + dp = hint_sharding(dp) + dp = hint_shuffling(dp) + return Mapper(dp, functools.partial(self._prepare_sample, config=config)) From ed267f6f7c000d8573db3e738fa7a88b0a6ce51a 
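For context on the file format handled by the `USPSFileReader` added in the patch above: each line of the decompressed `usps.bz2` archive is a libsvm-style text record, a 1-based label followed by `index:value` pairs holding 256 pixel intensities in [-1, 1] that form a 16x16 grayscale digit. A minimal, self-contained sketch of that per-line decoding; the function name and the synthetic input are illustrative only and not part of the patch:

```python
import numpy as np

def parse_usps_line(line: str):
    # "<label> 1:<v1> 2:<v2> ... 256:<v256>" -> ((1, 16, 16) uint8 image, 0-based label)
    raw = line.split()
    pixels = [field.split(":")[-1] for field in raw[1:]]
    img = np.asarray(pixels, dtype=np.float32).reshape((-1, 16, 16))
    img = ((img + 1) / 2 * 255).astype(np.uint8)  # map [-1, 1] to [0, 255]
    label = int(raw[0]) - 1                       # labels are stored as 1..10 on disk
    return img, label

# Synthetic 256-feature record standing in for a real line of the file:
features = " ".join(f"{i + 1}:{v:.3f}" for i, v in enumerate(np.linspace(-1, 1, 256)))
img, label = parse_usps_line("7 " + features)
print(img.shape, img.dtype, label)  # (1, 16, 16) uint8 6
```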
Mon Sep 17 00:00:00 2001 From: Lezwon Castelino Date: Sat, 19 Mar 2022 22:03:56 +0530 Subject: [PATCH 02/20] fixed type issues --- torchvision/prototype/datasets/_builtin/usps.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/prototype/datasets/_builtin/usps.py b/torchvision/prototype/datasets/_builtin/usps.py index 92efe1e582b..4e302036bfa 100644 --- a/torchvision/prototype/datasets/_builtin/usps.py +++ b/torchvision/prototype/datasets/_builtin/usps.py @@ -14,7 +14,7 @@ class USPSFileReader(IterDataPipe[torch.Tensor]): def __init__(self, datapipe: IterDataPipe[Tuple[Any, BinaryIO]]) -> None: self.datapipe = datapipe - def __iter__(self) -> Iterator[torch.Tensor]: + def __iter__(self) -> Iterator[Tuple[torch.Tensor, torch.Tensor]]: for path, _ in self.datapipe: with bz2.open(path) as fp: datapipe = IterableWrapper([(path, fp)]) From 16123efe13749a076a8acc52c874acf97e3aeee3 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Sun, 20 Mar 2022 11:25:22 +0100 Subject: [PATCH 03/20] fix mobilnet norm layer test (#5643) * xfail mobilnet norm layer test * fix test --- test/test_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_models.py b/test/test_models.py index 209f27209bf..d657475bafb 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -406,7 +406,7 @@ def test_mobilenet_norm_layer(model_fn): assert any(isinstance(x, nn.BatchNorm2d) for x in model.modules()) def get_gn(num_channels): - return nn.GroupNorm(32, num_channels) + return nn.GroupNorm(1, num_channels) model = model_fn(norm_layer=get_gn) assert not (any(isinstance(x, nn.BatchNorm2d) for x in model.modules())) From a37b3e9576e0706534c59ee4b67e987696414d1e Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 21 Mar 2022 15:56:12 +0000 Subject: [PATCH 04/20] More robust check in tests for 16 bits images (#5652) --- test/test_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_image.py b/test/test_image.py index f9e88e8ad2f..048f7c364b1 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -158,7 +158,7 @@ def test_decode_png(img_path, pil_mode, mode): img_pil = normalize_dimensions(img_pil) - if "16" in img_path: + if img_path.endswith("16.png"): # 16 bits image decoding is supported, but only as a private API # FIXME: see https://github.com/pytorch/vision/issues/4731 for potential solutions to making it public with pytest.raises(RuntimeError, match="At most 8-bit PNG images are supported"): From 7fd2ea01243264eed3db343b1dfa96651f6967a0 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Mon, 21 Mar 2022 10:58:56 -0700 Subject: [PATCH 05/20] Prefer nvidia channel for conda builds (#5648) To mitigate missing `libcupti.so` dependency --- .circleci/unittest/linux/scripts/install.sh | 4 ++-- packaging/build_cmake.sh | 2 +- packaging/build_conda.sh | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.circleci/unittest/linux/scripts/install.sh b/.circleci/unittest/linux/scripts/install.sh index b29fd3e09b8..59e033fea2b 100755 --- a/.circleci/unittest/linux/scripts/install.sh +++ b/.circleci/unittest/linux/scripts/install.sh @@ -21,7 +21,7 @@ else fi echo "Using CUDA $CUDA_VERSION as determined by CU_VERSION" version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")" - cudatoolkit="cudatoolkit=${version}" + cudatoolkit="nvidia::cudatoolkit=${version}" fi case "$(uname -s)" in @@ -33,7 +33,7 @@ printf "Installing PyTorch with %s\n" "${cudatoolkit}" if [ "${os}" == "MacOSX" ]; 
then conda install -y -c "pytorch-${UPLOAD_CHANNEL}" "pytorch-${UPLOAD_CHANNEL}"::pytorch "${cudatoolkit}" pytest else - conda install -y -c "pytorch-${UPLOAD_CHANNEL}" "pytorch-${UPLOAD_CHANNEL}"::pytorch[build="*${version}*"] "${cudatoolkit}" pytest + conda install -y -c "pytorch-${UPLOAD_CHANNEL}" -c nvidia "pytorch-${UPLOAD_CHANNEL}"::pytorch[build="*${version}*"] "${cudatoolkit}" pytest fi printf "* Installing torchvision\n" diff --git a/packaging/build_cmake.sh b/packaging/build_cmake.sh index 227b2b39519..17b62c55d82 100755 --- a/packaging/build_cmake.sh +++ b/packaging/build_cmake.sh @@ -42,7 +42,7 @@ else PYTORCH_MUTEX_CONSTRAINT='' fi -conda install -yq \pytorch=$PYTORCH_VERSION $CONDA_CUDATOOLKIT_CONSTRAINT $PYTORCH_MUTEX_CONSTRAINT $MKL_CONSTRAINT numpy -c "pytorch-${UPLOAD_CHANNEL}" +conda install -yq \pytorch=$PYTORCH_VERSION $CONDA_CUDATOOLKIT_CONSTRAINT $PYTORCH_MUTEX_CONSTRAINT $MKL_CONSTRAINT numpy -c nvidia -c "pytorch-${UPLOAD_CHANNEL}" TORCH_PATH=$(dirname $(python -c "import torch; print(torch.__file__)")) if [[ "$(uname)" == Darwin || "$OSTYPE" == "msys" ]]; then diff --git a/packaging/build_conda.sh b/packaging/build_conda.sh index 2d5d72f3ad2..b127e0f10fa 100755 --- a/packaging/build_conda.sh +++ b/packaging/build_conda.sh @@ -18,4 +18,4 @@ if [[ "$CU_VERSION" == cu115 ]]; then export CUDATOOLKIT_CHANNEL="conda-forge" fi -conda build -c defaults -c $CUDATOOLKIT_CHANNEL $CONDA_CHANNEL_FLAGS --no-anaconda-upload --python "$PYTHON_VERSION" packaging/torchvision +conda build -c $CUDATOOLKIT_CHANNEL -c defaults $CONDA_CHANNEL_FLAGS --no-anaconda-upload --python "$PYTHON_VERSION" packaging/torchvision From 2efd0f2bb4222056bc92a3b565531c3874ba96db Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 22 Mar 2022 10:27:27 +0100 Subject: [PATCH 06/20] fix torchdata CI installation (#5657) --- .circleci/config.yml | 10 ++++++++-- .circleci/config.yml.in | 10 ++++++++-- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index c68f39642ca..e3211e5da29 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -155,8 +155,14 @@ commands: install_prototype_dependencies: steps: - pip_install: - args: iopath git+https://github.com/pytorch/data - descr: Install prototype dependencies + args: iopath + descr: Install third-party dependencies + - pip_install: + args: -r https://raw.githubusercontent.com/pytorch/data/main/requirements.txt + descr: Install torchdata build dependencies + - pip_install: + args: --no-build-isolation git+https://github.com/pytorch/data + descr: Install torchdata from source # Most of the test suite is handled by the `unittest` jobs, with completely different workflow and setup. # This command can be used if only a selection of tests need to be run, for ad-hoc files. 
diff --git a/.circleci/config.yml.in b/.circleci/config.yml.in index 731d0133528..0d99a942338 100644 --- a/.circleci/config.yml.in +++ b/.circleci/config.yml.in @@ -155,8 +155,14 @@ commands: install_prototype_dependencies: steps: - pip_install: - args: iopath git+https://github.com/pytorch/data - descr: Install prototype dependencies + args: iopath + descr: Install third-party dependencies + - pip_install: + args: -r https://raw.githubusercontent.com/pytorch/data/main/requirements.txt + descr: Install torchdata build dependencies + - pip_install: + args: --no-build-isolation git+https://github.com/pytorch/data + descr: Install torchdata from source # Most of the test suite is handled by the `unittest` jobs, with completely different workflow and setup. # This command can be used if only a selection of tests need to be run, for ad-hoc files. From 834c8d9aef32559f8f60e9e68035ff6b8e499ed2 Mon Sep 17 00:00:00 2001 From: Sahil Goyal Date: Tue, 22 Mar 2022 18:22:10 +0530 Subject: [PATCH 07/20] update urls for kinetics dataset (#5578) * update urls for kinetics dataset * update urls for kinetics dataset * remove errors * update the changes and add test option to split * added test to valid values for split arg * change .txt to .csv for annotation url of k600 Co-authored-by: Nicolas Hug --- torchvision/datasets/kinetics.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/torchvision/datasets/kinetics.py b/torchvision/datasets/kinetics.py index e1ac4ac500c..651dbdc158f 100644 --- a/torchvision/datasets/kinetics.py +++ b/torchvision/datasets/kinetics.py @@ -1,6 +1,7 @@ import csv import os import time +import urllib import warnings from functools import partial from multiprocessing import Pool @@ -53,7 +54,7 @@ class Kinetics(VisionDataset): Note: split is appended automatically using the split argument. frames_per_clip (int): number of frames in a clip num_classes (int): select between Kinetics-400 (default), Kinetics-600, and Kinetics-700 - split (str): split of the dataset to consider; supports ``"train"`` (default) ``"val"`` + split (str): split of the dataset to consider; supports ``"train"`` (default) ``"val"`` ``"test"`` frame_rate (float): If omitted, interpolate different frame rate for each clip. 
step_between_clips (int): number of frames between each clip transform (callable, optional): A function/transform that takes in a TxHxWxC video @@ -81,7 +82,7 @@ class Kinetics(VisionDataset): } _ANNOTATION_URLS = { "400": "https://s3.amazonaws.com/kinetics/400/annotations/{split}.csv", - "600": "https://s3.amazonaws.com/kinetics/600/annotations/{split}.txt", + "600": "https://s3.amazonaws.com/kinetics/600/annotations/{split}.csv", "700": "https://s3.amazonaws.com/kinetics/700_2020/annotations/{split}.csv", } @@ -122,7 +123,7 @@ def __init__( raise ValueError("Cannot download the videos using legacy_structure.") else: self.split_folder = path.join(root, split) - self.split = verify_str_arg(split, arg="split", valid_values=["train", "val"]) + self.split = verify_str_arg(split, arg="split", valid_values=["train", "val", "test"]) if download: self.download_and_process_videos() @@ -177,17 +178,16 @@ def _download_videos(self) -> None: split_url_filepath = path.join(file_list_path, path.basename(split_url)) if not check_integrity(split_url_filepath): download_url(split_url, file_list_path) - list_video_urls = open(split_url_filepath) + with open(split_url_filepath) as file: + list_video_urls = [urllib.parse.quote(line, safe="/,:") for line in file.read().splitlines()] if self.num_download_workers == 1: - for line in list_video_urls.readlines(): - line = str(line).replace("\n", "") + for line in list_video_urls: download_and_extract_archive(line, tar_path, self.split_folder) else: part = partial(_dl_wrap, tar_path, self.split_folder) - lines = [str(line).replace("\n", "") for line in list_video_urls.readlines()] poolproc = Pool(self.num_download_workers) - poolproc.map(part, lines) + poolproc.map(part, list_video_urls) def _make_ds_structure(self) -> None: """move videos from From 2ef164a3b98ee21a2bb628fe42f87f4977758021 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 22 Mar 2022 16:30:20 +0000 Subject: [PATCH 08/20] Port Multi-weight support from prototype to main (#5618) * Moving basefiles outside of prototype and porting Alexnet, ConvNext, Densenet and EfficientNet. * Porting googlenet * Porting inception * Porting mnasnet * Porting mobilenetv2 * Porting mobilenetv3 * Porting regnet * Porting resnet * Porting shufflenetv2 * Porting squeezenet * Porting vgg * Porting vit * Fix docstrings * Fixing imports * Adding missing import * Fix mobilenet imports * Fix tests * Fix prototype tests * Exclude get_weight from models on test * Fix init files * Porting googlenet * Porting inception * porting mobilenetv2 * porting mobilenetv3 * porting resnet * porting shufflenetv2 * Fix test and linter * Fixing docs. 
* Porting Detection models (#5617) * fix inits * fix docs * Port faster_rcnn * Port fcos * Port keypoint_rcnn * Port mask_rcnn * Port retinanet * Port ssd * Port ssdlite * Fix linter * Fixing tests * Fixing tests * Fixing vgg test * Porting Optical Flow, Segmentation, Video models (#5619) * Porting raft * Porting video resnet * Porting deeplabv3 * Porting fcn and lraspp * Fixing the tests and linter * Porting docs, examples, tutorials and galleries (#5620) * Fix examples, tutorials and gallery * Update gallery/plot_optical_flow.py Co-authored-by: Nicolas Hug * Fix import * Revert hardcoded normalization * fix uncommitted changes * Fix bug * Fix more bugs * Making resize optional for segmentation * Fixing preset * Fix mypy * Fixing documentation strings * Fix flake8 * minor refactoring Co-authored-by: Nicolas Hug * Resolve conflict * Porting model tests (#5622) * Porting tests * Remove unnecessary variable * Fix linter * Move prototype to extended tests * Fix download models job * Update CI on Multiweight branch to use the new weight download approach (#5628) * port Pad to prototype transforms (#5621) * port Pad to prototype transforms * use literal * Bump up LibTorchvision version number for Podspec to release Cocoapods (#5624) Co-authored-by: Anton Thomma Co-authored-by: Vasilis Vryniotis * pre-download model weights in CI docs build (#5625) * pre-download model weights in CI docs build * move changes into template * change docs image * Regenerated config.yml Co-authored-by: Philip Meier Co-authored-by: Anton Thomma <11010310+thommaa@users.noreply.github.com> Co-authored-by: Anton Thomma * Porting reference scripts and updating presets (#5629) * Making _preset.py classes * Remove support of targets on presets. * Rewriting the video preset * Adding tests to check that the bundled transforms are JIT scriptable * Rename all presets from *Eval to *Inference * Minor refactoring * Remove --prototype and --pretrained from reference scripts * remove pretained_backbone refs * Corrections and simplifications * Fixing bug * Fixing linter * Fix flake8 * restore documentation example * minor fixes * fix optical flow missing param * Fixing commands * Adding weights_backbone support in detection and segmentation * Updating the commands for InceptionV3 * Setting `weights_backbone` to its fully BC value (#5653) * Replace default `weights_backbone=None` with its BC values. * Fixing tests * Fix linter * Update docs. * Update preprocessing on reference scripts. * Change qat/ptq to their full values. * Refactoring preprocessing * Fix video preset * No initialization on VGG if pretrained * Fix warning messages for backbone utils. * Adding star to all preset constructors. * Fix mypy. 
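The port described above replaces the boolean `pretrained` flag with explicit weight enums across classification, detection, segmentation, video and optical-flow models, and moves the inference presets onto the weights themselves. A rough sketch of the resulting user-facing pattern, as it appears in the gallery and reference-script changes later in this series (the dummy input is a placeholder, not taken from the patch):

```python
import torch
from torchvision.models import resnet18, ResNet18_Weights, get_weight

# Explicit weight enum instead of pretrained=True; DEFAULT points at the best available weights.
weights = ResNet18_Weights.DEFAULT
model = resnet18(weights=weights).eval()

# Each weight entry bundles the preprocessing it was trained with.
preprocess = weights.transforms()
batch = preprocess(torch.randint(0, 256, (1, 3, 256, 256), dtype=torch.uint8))

with torch.no_grad():
    scores = model(batch)
print(scores.shape)  # torch.Size([1, 1000])

# The reference scripts resolve weights from a string, e.g. --weights ResNet18_Weights.IMAGENET1K_V1
same_weights = get_weight("ResNet18_Weights.IMAGENET1K_V1")
```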
Co-authored-by: Nicolas Hug Co-authored-by: Philip Meier Co-authored-by: Anton Thomma <11010310+thommaa@users.noreply.github.com> Co-authored-by: Anton Thomma --- .circleci/config.yml | 20 +- .circleci/config.yml.in | 20 +- android/test_app/make_assets.py | 13 +- docs/source/models.rst | 54 +- examples/cpp/hello_world/trace_model.py | 2 +- gallery/plot_optical_flow.py | 34 +- gallery/plot_repurposing_annotations.py | 8 +- gallery/plot_scripted_tensor_transforms.py | 12 +- gallery/plot_visualization_utils.py | 39 +- ios/VisionTestApp/make_assets.py | 13 +- references/classification/README.md | 26 +- references/classification/presets.py | 2 + references/classification/train.py | 42 +- .../classification/train_quantization.py | 23 +- references/classification/utils.py | 8 +- references/detection/README.md | 18 +- references/detection/train.py | 48 +- references/optical_flow/README.md | 4 +- references/optical_flow/presets.py | 1 + references/optical_flow/train.py | 45 +- references/segmentation/README.md | 12 +- references/segmentation/presets.py | 4 +- references/segmentation/train.py | 57 +- references/video_classification/presets.py | 5 +- references/video_classification/train.py | 42 +- test/test_backbone_utils.py | 26 +- test/test_cpp_models.py | 59 +- ...type_models.py => test_extended_models.py} | 120 +--- test/test_hub.py | 4 +- test/test_models.py | 18 +- .../test_models_detection_negative_samples.py | 12 +- test/test_models_detection_utils.py | 10 +- test/test_onnx.py | 14 +- test/tracing/frcnn/trace_model.py | 2 +- torchvision/models/__init__.py | 16 +- torchvision/{prototype => }/models/_api.py | 40 +- torchvision/{prototype => }/models/_meta.py | 0 torchvision/models/_utils.py | 165 ++++- torchvision/models/alexnet.py | 53 +- torchvision/models/convnext.py | 142 ++++- torchvision/models/densenet.py | 147 ++++- torchvision/models/detection/__init__.py | 4 +- .../models/detection/backbone_utils.py | 35 +- torchvision/models/detection/faster_rcnn.py | 229 ++++--- torchvision/models/detection/fcos.py | 83 ++- torchvision/models/detection/keypoint_rcnn.py | 126 ++-- torchvision/models/detection/mask_rcnn.py | 89 ++- torchvision/models/detection/retinanet.py | 87 ++- torchvision/models/detection/roi_heads.py | 3 +- torchvision/models/detection/ssd.py | 101 +-- torchvision/models/detection/ssdlite.py | 92 ++- torchvision/models/efficientnet.py | 449 +++++++++++--- torchvision/models/googlenet.py | 75 ++- torchvision/models/inception.py | 68 +- torchvision/models/mnasnet.py | 140 +++-- torchvision/models/mobilenet.py | 6 +- torchvision/models/mobilenetv2.py | 69 ++- torchvision/models/mobilenetv3.py | 117 +++- torchvision/models/optical_flow/__init__.py | 2 +- torchvision/models/optical_flow/raft.py | 191 +++++- torchvision/models/quantization/__init__.py | 4 +- torchvision/models/quantization/googlenet.py | 107 ++-- torchvision/models/quantization/inception.py | 105 ++-- torchvision/models/quantization/mobilenet.py | 6 +- .../models/quantization/mobilenetv2.py | 85 ++- .../models/quantization/mobilenetv3.py | 100 ++- torchvision/models/quantization/resnet.py | 203 ++++-- .../models/quantization/shufflenetv2.py | 125 +++- torchvision/models/regnet.py | 583 +++++++++++++++--- torchvision/models/resnet.py | 395 ++++++++++-- torchvision/models/segmentation/__init__.py | 2 +- torchvision/models/segmentation/_utils.py | 8 - torchvision/models/segmentation/deeplabv3.py | 199 ++++-- torchvision/models/segmentation/fcn.py | 139 +++-- torchvision/models/segmentation/lraspp.py | 78 ++- 
torchvision/models/shufflenetv2.py | 138 ++++- torchvision/models/squeezenet.py | 93 ++- torchvision/models/vgg.py | 257 ++++++-- torchvision/models/video/resnet.py | 157 +++-- torchvision/models/vision_transformer.py | 147 ++++- torchvision/prototype/__init__.py | 1 - torchvision/prototype/models/__init__.py | 20 - torchvision/prototype/models/_utils.py | 108 ---- torchvision/prototype/models/alexnet.py | 49 -- torchvision/prototype/models/convnext.py | 169 ----- torchvision/prototype/models/densenet.py | 159 ----- .../prototype/models/detection/__init__.py | 7 - .../prototype/models/detection/faster_rcnn.py | 228 ------- .../prototype/models/detection/fcos.py | 80 --- .../models/detection/keypoint_rcnn.py | 108 ---- .../prototype/models/detection/mask_rcnn.py | 81 --- .../prototype/models/detection/retinanet.py | 84 --- torchvision/prototype/models/detection/ssd.py | 93 --- .../prototype/models/detection/ssdlite.py | 129 ---- torchvision/prototype/models/efficientnet.py | 453 -------------- torchvision/prototype/models/googlenet.py | 63 -- torchvision/prototype/models/inception.py | 57 -- torchvision/prototype/models/mnasnet.py | 113 ---- torchvision/prototype/models/mobilenet.py | 6 - torchvision/prototype/models/mobilenetv2.py | 66 -- torchvision/prototype/models/mobilenetv3.py | 109 ---- .../prototype/models/optical_flow/__init__.py | 1 - .../prototype/models/optical_flow/raft.py | 251 -------- .../prototype/models/quantization/__init__.py | 5 - .../models/quantization/googlenet.py | 94 --- .../models/quantization/inception.py | 90 --- .../models/quantization/mobilenet.py | 6 - .../models/quantization/mobilenetv2.py | 81 --- .../models/quantization/mobilenetv3.py | 101 --- .../prototype/models/quantization/resnet.py | 204 ------ .../models/quantization/shufflenetv2.py | 136 ---- torchvision/prototype/models/regnet.py | 575 ----------------- torchvision/prototype/models/resnet.py | 381 ------------ .../prototype/models/segmentation/__init__.py | 3 - .../models/segmentation/deeplabv3.py | 174 ------ .../prototype/models/segmentation/fcn.py | 117 ---- .../prototype/models/segmentation/lraspp.py | 66 -- torchvision/prototype/models/shufflenetv2.py | 124 ---- torchvision/prototype/models/squeezenet.py | 88 --- torchvision/prototype/models/vgg.py | 240 ------- .../prototype/models/video/__init__.py | 1 - torchvision/prototype/models/video/resnet.py | 152 ----- .../prototype/models/vision_transformer.py | 198 ------ torchvision/prototype/transforms/__init__.py | 9 - .../prototype/transforms/_auto_augment.py | 5 +- torchvision/prototype/transforms/_geometry.py | 4 +- .../transforms/functional/_geometry.py | 3 +- torchvision/prototype/utils/_internal.py | 55 -- .../{prototype => }/transforms/_presets.py | 103 ++-- 129 files changed, 4581 insertions(+), 7146 deletions(-) rename test/{test_prototype_models.py => test_extended_models.py} (70%) rename torchvision/{prototype => }/models/_api.py (72%) rename torchvision/{prototype => }/models/_meta.py (100%) delete mode 100644 torchvision/prototype/models/__init__.py delete mode 100644 torchvision/prototype/models/_utils.py delete mode 100644 torchvision/prototype/models/alexnet.py delete mode 100644 torchvision/prototype/models/convnext.py delete mode 100644 torchvision/prototype/models/densenet.py delete mode 100644 torchvision/prototype/models/detection/__init__.py delete mode 100644 torchvision/prototype/models/detection/faster_rcnn.py delete mode 100644 torchvision/prototype/models/detection/fcos.py delete mode 100644 
torchvision/prototype/models/detection/keypoint_rcnn.py delete mode 100644 torchvision/prototype/models/detection/mask_rcnn.py delete mode 100644 torchvision/prototype/models/detection/retinanet.py delete mode 100644 torchvision/prototype/models/detection/ssd.py delete mode 100644 torchvision/prototype/models/detection/ssdlite.py delete mode 100644 torchvision/prototype/models/efficientnet.py delete mode 100644 torchvision/prototype/models/googlenet.py delete mode 100644 torchvision/prototype/models/inception.py delete mode 100644 torchvision/prototype/models/mnasnet.py delete mode 100644 torchvision/prototype/models/mobilenet.py delete mode 100644 torchvision/prototype/models/mobilenetv2.py delete mode 100644 torchvision/prototype/models/mobilenetv3.py delete mode 100644 torchvision/prototype/models/optical_flow/__init__.py delete mode 100644 torchvision/prototype/models/optical_flow/raft.py delete mode 100644 torchvision/prototype/models/quantization/__init__.py delete mode 100644 torchvision/prototype/models/quantization/googlenet.py delete mode 100644 torchvision/prototype/models/quantization/inception.py delete mode 100644 torchvision/prototype/models/quantization/mobilenet.py delete mode 100644 torchvision/prototype/models/quantization/mobilenetv2.py delete mode 100644 torchvision/prototype/models/quantization/mobilenetv3.py delete mode 100644 torchvision/prototype/models/quantization/resnet.py delete mode 100644 torchvision/prototype/models/quantization/shufflenetv2.py delete mode 100644 torchvision/prototype/models/regnet.py delete mode 100644 torchvision/prototype/models/resnet.py delete mode 100644 torchvision/prototype/models/segmentation/__init__.py delete mode 100644 torchvision/prototype/models/segmentation/deeplabv3.py delete mode 100644 torchvision/prototype/models/segmentation/fcn.py delete mode 100644 torchvision/prototype/models/segmentation/lraspp.py delete mode 100644 torchvision/prototype/models/shufflenetv2.py delete mode 100644 torchvision/prototype/models/squeezenet.py delete mode 100644 torchvision/prototype/models/vgg.py delete mode 100644 torchvision/prototype/models/video/__init__.py delete mode 100644 torchvision/prototype/models/video/resnet.py delete mode 100644 torchvision/prototype/models/vision_transformer.py rename torchvision/{prototype => }/transforms/_presets.py (53%) diff --git a/.circleci/config.yml b/.circleci/config.yml index e3211e5da29..8c8600d17ad 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -366,19 +366,28 @@ jobs: resource_class: xlarge steps: - checkout - - download_model_weights: - extract_roots: torchvision/prototype/models - install_torchvision - install_prototype_dependencies - pip_install: args: scipy pycocotools h5py descr: Install optional dependencies - - run: - name: Enable prototype tests - command: echo 'export PYTORCH_TEST_WITH_PROTOTYPE=1' >> $BASH_ENV - run_tests_selective: file_or_dir: test/test_prototype_*.py + unittest_extended: + docker: + - image: circleci/python:3.7 + resource_class: xlarge + steps: + - checkout + - download_model_weights + - install_torchvision + - run: + name: Enable extended tests + command: echo 'export PYTORCH_TEST_WITH_EXTENDED=1' >> $BASH_ENV + - run_tests_selective: + file_or_dir: test/test_extended_*.py + binary_linux_wheel: <<: *binary_common docker: @@ -1629,6 +1638,7 @@ workflows: - unittest_torchhub - unittest_onnx - unittest_prototype + - unittest_extended - unittest_linux_cpu: cu_version: cpu name: unittest_linux_cpu_py3.7 diff --git a/.circleci/config.yml.in 
b/.circleci/config.yml.in index 0d99a942338..7188d435a82 100644 --- a/.circleci/config.yml.in +++ b/.circleci/config.yml.in @@ -366,19 +366,28 @@ jobs: resource_class: xlarge steps: - checkout - - download_model_weights: - extract_roots: torchvision/prototype/models - install_torchvision - install_prototype_dependencies - pip_install: args: scipy pycocotools h5py descr: Install optional dependencies - - run: - name: Enable prototype tests - command: echo 'export PYTORCH_TEST_WITH_PROTOTYPE=1' >> $BASH_ENV - run_tests_selective: file_or_dir: test/test_prototype_*.py + unittest_extended: + docker: + - image: circleci/python:3.7 + resource_class: xlarge + steps: + - checkout + - download_model_weights + - install_torchvision + - run: + name: Enable extended tests + command: echo 'export PYTORCH_TEST_WITH_EXTENDED=1' >> $BASH_ENV + - run_tests_selective: + file_or_dir: test/test_extended_*.py + binary_linux_wheel: <<: *binary_common docker: @@ -1115,6 +1124,7 @@ workflows: - unittest_torchhub - unittest_onnx - unittest_prototype + - unittest_extended {{ unittest_workflows() }} cmake: diff --git a/android/test_app/make_assets.py b/android/test_app/make_assets.py index fedee39fc52..f99933e9a9d 100644 --- a/android/test_app/make_assets.py +++ b/android/test_app/make_assets.py @@ -1,11 +1,18 @@ import torch -import torchvision from torch.utils.mobile_optimizer import optimize_for_mobile +from torchvision.models.detection import ( + fasterrcnn_mobilenet_v3_large_320_fpn, + FasterRCNN_MobileNet_V3_Large_320_FPN_Weights, +) print(torch.__version__) -model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn( - pretrained=True, box_score_thresh=0.7, rpn_post_nms_top_n_test=100, rpn_score_thresh=0.4, rpn_pre_nms_top_n_test=150 +model = fasterrcnn_mobilenet_v3_large_320_fpn( + weights=FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.DEFAULT, + box_score_thresh=0.7, + rpn_post_nms_top_n_test=100, + rpn_score_thresh=0.4, + rpn_pre_nms_top_n_test=150, ) model.eval() diff --git a/docs/source/models.rst b/docs/source/models.rst index 50af05360e4..39543cb8027 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -98,58 +98,6 @@ You can construct a model with random weights by calling its constructor: convnext_large = models.convnext_large() We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`. -These can be constructed by passing ``pretrained=True``: - -.. 
code:: python - - import torchvision.models as models - resnet18 = models.resnet18(pretrained=True) - alexnet = models.alexnet(pretrained=True) - squeezenet = models.squeezenet1_0(pretrained=True) - vgg16 = models.vgg16(pretrained=True) - densenet = models.densenet161(pretrained=True) - inception = models.inception_v3(pretrained=True) - googlenet = models.googlenet(pretrained=True) - shufflenet = models.shufflenet_v2_x1_0(pretrained=True) - mobilenet_v2 = models.mobilenet_v2(pretrained=True) - mobilenet_v3_large = models.mobilenet_v3_large(pretrained=True) - mobilenet_v3_small = models.mobilenet_v3_small(pretrained=True) - resnext50_32x4d = models.resnext50_32x4d(pretrained=True) - wide_resnet50_2 = models.wide_resnet50_2(pretrained=True) - mnasnet = models.mnasnet1_0(pretrained=True) - efficientnet_b0 = models.efficientnet_b0(pretrained=True) - efficientnet_b1 = models.efficientnet_b1(pretrained=True) - efficientnet_b2 = models.efficientnet_b2(pretrained=True) - efficientnet_b3 = models.efficientnet_b3(pretrained=True) - efficientnet_b4 = models.efficientnet_b4(pretrained=True) - efficientnet_b5 = models.efficientnet_b5(pretrained=True) - efficientnet_b6 = models.efficientnet_b6(pretrained=True) - efficientnet_b7 = models.efficientnet_b7(pretrained=True) - efficientnet_v2_s = models.efficientnet_v2_s(pretrained=True) - efficientnet_v2_m = models.efficientnet_v2_m(pretrained=True) - efficientnet_v2_l = models.efficientnet_v2_l(pretrained=True) - regnet_y_400mf = models.regnet_y_400mf(pretrained=True) - regnet_y_800mf = models.regnet_y_800mf(pretrained=True) - regnet_y_1_6gf = models.regnet_y_1_6gf(pretrained=True) - regnet_y_3_2gf = models.regnet_y_3_2gf(pretrained=True) - regnet_y_8gf = models.regnet_y_8gf(pretrained=True) - regnet_y_16gf = models.regnet_y_16gf(pretrained=True) - regnet_y_32gf = models.regnet_y_32gf(pretrained=True) - regnet_x_400mf = models.regnet_x_400mf(pretrained=True) - regnet_x_800mf = models.regnet_x_800mf(pretrained=True) - regnet_x_1_6gf = models.regnet_x_1_6gf(pretrained=True) - regnet_x_3_2gf = models.regnet_x_3_2gf(pretrained=True) - regnet_x_8gf = models.regnet_x_8gf(pretrained=True) - regnet_x_16gf = models.regnet_x_16gf(pretrainedTrue) - regnet_x_32gf = models.regnet_x_32gf(pretrained=True) - vit_b_16 = models.vit_b_16(pretrained=True) - vit_b_32 = models.vit_b_32(pretrained=True) - vit_l_16 = models.vit_l_16(pretrained=True) - vit_l_32 = models.vit_l_32(pretrained=True) - convnext_tiny = models.convnext_tiny(pretrained=True) - convnext_small = models.convnext_small(pretrained=True) - convnext_base = models.convnext_base(pretrained=True) - convnext_large = models.convnext_large(pretrained=True) Instancing a pre-trained model will download its weights to a cache directory. This directory can be set using the `TORCH_HOME` environment variable. See @@ -525,7 +473,7 @@ Obtaining a pre-trained quantized model can be done with a few lines of code: .. 
code:: python import torchvision.models as models - model = models.quantization.mobilenet_v2(pretrained=True, quantize=True) + model = models.quantization.mobilenet_v2(weights=MobileNet_V2_QuantizedWeights.IMAGENET1K_QNNPACK_V1, quantize=True) model.eval() # run the model with quantized inputs and weights out = model(torch.rand(1, 3, 224, 224)) diff --git a/examples/cpp/hello_world/trace_model.py b/examples/cpp/hello_world/trace_model.py index c8b8d6911e7..41bbaf8b6dd 100644 --- a/examples/cpp/hello_world/trace_model.py +++ b/examples/cpp/hello_world/trace_model.py @@ -6,7 +6,7 @@ HERE = osp.dirname(osp.abspath(__file__)) ASSETS = osp.dirname(osp.dirname(HERE)) -model = torchvision.models.resnet18(pretrained=False) +model = torchvision.models.resnet18() model.eval() traced_model = torch.jit.script(model) diff --git a/gallery/plot_optical_flow.py b/gallery/plot_optical_flow.py index 505334f36da..5149ebc541b 100644 --- a/gallery/plot_optical_flow.py +++ b/gallery/plot_optical_flow.py @@ -19,7 +19,6 @@ import torch import matplotlib.pyplot as plt import torchvision.transforms.functional as F -import torchvision.transforms as T plt.rcParams["savefig.bbox"] = "tight" @@ -88,24 +87,19 @@ def plot(imgs, **imshow_kwargs): # reduce the image sizes for the example to run faster. Image dimension must be # divisible by 8. +from torchvision.models.optical_flow import Raft_Large_Weights -def preprocess(batch): - transforms = T.Compose( - [ - T.ConvertImageDtype(torch.float32), - T.Normalize(mean=0.5, std=0.5), # map [0, 1] into [-1, 1] - T.Resize(size=(520, 960)), - ] - ) - batch = transforms(batch) - return batch +weights = Raft_Large_Weights.DEFAULT +transforms = weights.transforms() -# If you can, run this example on a GPU, it will be a lot faster. -device = "cuda" if torch.cuda.is_available() else "cpu" +def preprocess(img1_batch, img2_batch): + img1_batch = F.resize(img1_batch, size=[520, 960]) + img2_batch = F.resize(img2_batch, size=[520, 960]) + return transforms(img1_batch, img2_batch) + -img1_batch = preprocess(img1_batch).to(device) -img2_batch = preprocess(img2_batch).to(device) +img1_batch, img2_batch = preprocess(img1_batch, img2_batch) print(f"shape = {img1_batch.shape}, dtype = {img1_batch.dtype}") @@ -121,7 +115,10 @@ def preprocess(batch): from torchvision.models.optical_flow import raft_large -model = raft_large(pretrained=True, progress=False).to(device) +# If you can, run this example on a GPU, it will be a lot faster. 
+device = "cuda" if torch.cuda.is_available() else "cpu" + +model = raft_large(weights=Raft_Large_Weights.DEFAULT, progress=False).to(device) model = model.eval() list_of_flows = model(img1_batch.to(device), img2_batch.to(device)) @@ -182,10 +179,9 @@ def preprocess(batch): # from torchvision.io import write_jpeg # for i, (img1, img2) in enumerate(zip(frames, frames[1:])): # # Note: it would be faster to predict batches of flows instead of individual flows -# img1 = preprocess(img1[None]).to(device) -# img2 = preprocess(img2[None]).to(device) +# img1, img2 = preprocess(img1, img2) -# list_of_flows = model(img1_batch, img2_batch) +# list_of_flows = model(img1.to(device), img1.to(device)) # predicted_flow = list_of_flows[-1][0] # flow_img = flow_to_image(predicted_flow).to("cpu") # output_folder = "/tmp/" # Update this to the folder of your choice diff --git a/gallery/plot_repurposing_annotations.py b/gallery/plot_repurposing_annotations.py index fb4835496c3..7bb68617a17 100644 --- a/gallery/plot_repurposing_annotations.py +++ b/gallery/plot_repurposing_annotations.py @@ -139,12 +139,14 @@ def show(imgs): # Here is demo with a Faster R-CNN model loaded from # :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` -from torchvision.models.detection import fasterrcnn_resnet50_fpn +from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_Weights -model = fasterrcnn_resnet50_fpn(pretrained=True, progress=False) +weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT +model = fasterrcnn_resnet50_fpn(weights=weights, progress=False) print(img.size()) -img = F.convert_image_dtype(img, torch.float) +tranforms = weights.transforms() +img = tranforms(img) target = {} target["boxes"] = boxes target["labels"] = labels = torch.ones((masks.size(0),), dtype=torch.int64) diff --git a/gallery/plot_scripted_tensor_transforms.py b/gallery/plot_scripted_tensor_transforms.py index a9205536821..995383d4603 100644 --- a/gallery/plot_scripted_tensor_transforms.py +++ b/gallery/plot_scripted_tensor_transforms.py @@ -85,20 +85,16 @@ def show(imgs): # Let's define a ``Predictor`` module that transforms the input tensor and then # applies an ImageNet model on it. -from torchvision.models import resnet18 +from torchvision.models import resnet18, ResNet18_Weights class Predictor(nn.Module): def __init__(self): super().__init__() - self.resnet18 = resnet18(pretrained=True, progress=False).eval() - self.transforms = nn.Sequential( - T.Resize([256, ]), # We use single int value inside a list due to torchscript type restrictions - T.CenterCrop(224), - T.ConvertImageDtype(torch.float), - T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) - ) + weights = ResNet18_Weights.DEFAULT + self.resnet18 = resnet18(weights=weights, progress=False).eval() + self.transforms = weights.transforms() def forward(self, x: torch.Tensor) -> torch.Tensor: with torch.no_grad(): diff --git a/gallery/plot_visualization_utils.py b/gallery/plot_visualization_utils.py index 526c8c32493..7f92d54ebdd 100644 --- a/gallery/plot_visualization_utils.py +++ b/gallery/plot_visualization_utils.py @@ -73,14 +73,17 @@ def show(imgs): # :func:`~torchvision.models.detection.ssd300_vgg16`. For more details # on the output of such models, you may refer to :ref:`instance_seg_output`. 
-from torchvision.models.detection import fasterrcnn_resnet50_fpn -from torchvision.transforms.functional import convert_image_dtype +from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_Weights batch_int = torch.stack([dog1_int, dog2_int]) -batch = convert_image_dtype(batch_int, dtype=torch.float) -model = fasterrcnn_resnet50_fpn(pretrained=True, progress=False) +weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT +transforms = weights.transforms() + +batch = transforms(batch_int) + +model = fasterrcnn_resnet50_fpn(weights=weights, progress=False) model = model.eval() outputs = model(batch) @@ -120,13 +123,15 @@ def show(imgs): # images must be normalized before they're passed to a semantic segmentation # model. -from torchvision.models.segmentation import fcn_resnet50 +from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights +weights = FCN_ResNet50_Weights.DEFAULT +transforms = weights.transforms(resize_size=None) -model = fcn_resnet50(pretrained=True, progress=False) +model = fcn_resnet50(weights=weights, progress=False) model = model.eval() -normalized_batch = F.normalize(batch, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) +normalized_batch = transforms(batch) output = model(normalized_batch)['out'] print(output.shape, output.min().item(), output.max().item()) @@ -262,8 +267,14 @@ def show(imgs): # of them may not have masks, like # :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn`. -from torchvision.models.detection import maskrcnn_resnet50_fpn -model = maskrcnn_resnet50_fpn(pretrained=True, progress=False) +from torchvision.models.detection import maskrcnn_resnet50_fpn, MaskRCNN_ResNet50_FPN_Weights + +weights = MaskRCNN_ResNet50_FPN_Weights.DEFAULT +transforms = weights.transforms() + +batch = transforms(batch_int) + +model = maskrcnn_resnet50_fpn(weights=weights, progress=False) model = model.eval() output = model(batch) @@ -378,13 +389,17 @@ def show(imgs): # Note that the keypoint detection model does not need normalized images. 
# -from torchvision.models.detection import keypointrcnn_resnet50_fpn +from torchvision.models.detection import keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights from torchvision.io import read_image person_int = read_image(str(Path("assets") / "person1.jpg")) -person_float = convert_image_dtype(person_int, dtype=torch.float) -model = keypointrcnn_resnet50_fpn(pretrained=True, progress=False) +weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT +transforms = weights.transforms() + +person_float = transforms(person_int) + +model = keypointrcnn_resnet50_fpn(weights=weights, progress=False) model = model.eval() outputs = model([person_float]) diff --git a/ios/VisionTestApp/make_assets.py b/ios/VisionTestApp/make_assets.py index 0f46364569b..f14223e6a42 100644 --- a/ios/VisionTestApp/make_assets.py +++ b/ios/VisionTestApp/make_assets.py @@ -1,11 +1,18 @@ import torch -import torchvision from torch.utils.mobile_optimizer import optimize_for_mobile +from torchvision.models.detection import ( + fasterrcnn_mobilenet_v3_large_320_fpn, + FasterRCNN_MobileNet_V3_Large_320_FPN_Weights, +) print(torch.__version__) -model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn( - pretrained=True, box_score_thresh=0.7, rpn_post_nms_top_n_test=100, rpn_score_thresh=0.4, rpn_pre_nms_top_n_test=150 +model = fasterrcnn_mobilenet_v3_large_320_fpn( + weights=FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.DEFAULT, + box_score_thresh=0.7, + rpn_post_nms_top_n_test=100, + rpn_score_thresh=0.4, + rpn_pre_nms_top_n_test=150, ) model.eval() diff --git a/references/classification/README.md b/references/classification/README.md index 173fb454995..c274c997791 100644 --- a/references/classification/README.md +++ b/references/classification/README.md @@ -43,7 +43,7 @@ Since it expects tensors with a size of N x 3 x 299 x 299, to validate the model ``` torchrun --nproc_per_node=8 train.py --model inception_v3\ - --val-resize-size 342 --val-crop-size 299 --train-crop-size 299 --test-only --pretrained + --test-only --weights Inception_V3_Weights.IMAGENET1K_V1 ``` ### ResNet @@ -96,22 +96,14 @@ The weights of the B5-B7 variants are ported from Luke Melas' [EfficientNet-PyTo All models were trained using Bicubic interpolation and each have custom crop and resize sizes. 
To validate the models use the following commands: ``` -torchrun --nproc_per_node=8 train.py --model efficientnet_b0 --interpolation bicubic\ - --val-resize-size 256 --val-crop-size 224 --train-crop-size 224 --test-only --pretrained -torchrun --nproc_per_node=8 train.py --model efficientnet_b1 --interpolation bicubic\ - --val-resize-size 256 --val-crop-size 240 --train-crop-size 240 --test-only --pretrained -torchrun --nproc_per_node=8 train.py --model efficientnet_b2 --interpolation bicubic\ - --val-resize-size 288 --val-crop-size 288 --train-crop-size 288 --test-only --pretrained -torchrun --nproc_per_node=8 train.py --model efficientnet_b3 --interpolation bicubic\ - --val-resize-size 320 --val-crop-size 300 --train-crop-size 300 --test-only --pretrained -torchrun --nproc_per_node=8 train.py --model efficientnet_b4 --interpolation bicubic\ - --val-resize-size 384 --val-crop-size 380 --train-crop-size 380 --test-only --pretrained -torchrun --nproc_per_node=8 train.py --model efficientnet_b5 --interpolation bicubic\ - --val-resize-size 456 --val-crop-size 456 --train-crop-size 456 --test-only --pretrained -torchrun --nproc_per_node=8 train.py --model efficientnet_b6 --interpolation bicubic\ - --val-resize-size 528 --val-crop-size 528 --train-crop-size 528 --test-only --pretrained -torchrun --nproc_per_node=8 train.py --model efficientnet_b7 --interpolation bicubic\ - --val-resize-size 600 --val-crop-size 600 --train-crop-size 600 --test-only --pretrained +torchrun --nproc_per_node=8 train.py --model efficientnet_b0 --test-only --weights EfficientNet_B0_Weights.IMAGENET1K_V1 +torchrun --nproc_per_node=8 train.py --model efficientnet_b1 --test-only --weights EfficientNet_B1_Weights.IMAGENET1K_V1 +torchrun --nproc_per_node=8 train.py --model efficientnet_b2 --test-only --weights EfficientNet_B2_Weights.IMAGENET1K_V1 +torchrun --nproc_per_node=8 train.py --model efficientnet_b3 --test-only --weights EfficientNet_B3_Weights.IMAGENET1K_V1 +torchrun --nproc_per_node=8 train.py --model efficientnet_b4 --test-only --weights EfficientNet_B4_Weights.IMAGENET1K_V1 +torchrun --nproc_per_node=8 train.py --model efficientnet_b5 --test-only --weights EfficientNet_B5_Weights.IMAGENET1K_V1 +torchrun --nproc_per_node=8 train.py --model efficientnet_b6 --test-only --weights EfficientNet_B6_Weights.IMAGENET1K_V1 +torchrun --nproc_per_node=8 train.py --model efficientnet_b7 --test-only --weights EfficientNet_B7_Weights.IMAGENET1K_V1 ``` diff --git a/references/classification/presets.py b/references/classification/presets.py index 418ef3e2e07..6bc38e72953 100644 --- a/references/classification/presets.py +++ b/references/classification/presets.py @@ -6,6 +6,7 @@ class ClassificationPresetTrain: def __init__( self, + *, crop_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), @@ -46,6 +47,7 @@ def __call__(self, img): class ClassificationPresetEval: def __init__( self, + *, crop_size, resize_size=256, mean=(0.485, 0.456, 0.406), diff --git a/references/classification/train.py b/references/classification/train.py index 569cf3009e7..eb8b56c1ad0 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -15,12 +15,6 @@ from torchvision.transforms.functional import InterpolationMode -try: - from torchvision import prototype -except ImportError: - prototype = None - - def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None): model.train() metric_logger = utils.MetricLogger(delimiter=" ") @@ -154,18 +148,13 @@ def 
load_data(traindir, valdir, args): print(f"Loading dataset_test from {cache_path}") dataset_test, _ = torch.load(cache_path) else: - if not args.prototype: + if args.weights and args.test_only: + weights = torchvision.models.get_weight(args.weights) + preprocessing = weights.transforms() + else: preprocessing = presets.ClassificationPresetEval( crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation ) - else: - if args.weights: - weights = prototype.models.get_weight(args.weights) - preprocessing = weights.transforms() - else: - preprocessing = prototype.transforms.ImageClassificationEval( - crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation - ) dataset_test = torchvision.datasets.ImageFolder( valdir, @@ -191,10 +180,6 @@ def load_data(traindir, valdir, args): def main(args): - if args.prototype and prototype is None: - raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.") - if not args.prototype and args.weights: - raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.") if args.output_dir: utils.mkdir(args.output_dir) @@ -236,10 +221,7 @@ def main(args): ) print("Creating model") - if not args.prototype: - model = torchvision.models.__dict__[args.model](pretrained=args.pretrained, num_classes=num_classes) - else: - model = prototype.models.__dict__[args.model](weights=args.weights, num_classes=num_classes) + model = torchvision.models.__dict__[args.model](weights=args.weights, num_classes=num_classes) model.to(device) if args.distributed and args.sync_bn: @@ -446,12 +428,6 @@ def get_args_parser(add_help=True): help="Only test the model", action="store_true", ) - parser.add_argument( - "--pretrained", - dest="pretrained", - help="Use pre-trained models from the modelzoo", - action="store_true", - ) parser.add_argument("--auto-augment", default=None, type=str, help="auto augment policy (default: None)") parser.add_argument("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)") @@ -496,14 +472,6 @@ def get_args_parser(add_help=True): parser.add_argument( "--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)" ) - - # Prototype models only - parser.add_argument( - "--prototype", - dest="prototype", - help="Use prototype model builders instead those from main area", - action="store_true", - ) parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") return parser diff --git a/references/classification/train_quantization.py b/references/classification/train_quantization.py index 111777a860b..c0e5af1dcfc 100644 --- a/references/classification/train_quantization.py +++ b/references/classification/train_quantization.py @@ -12,17 +12,7 @@ from train import train_one_epoch, evaluate, load_data -try: - from torchvision import prototype -except ImportError: - prototype = None - - def main(args): - if args.prototype and prototype is None: - raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.") - if not args.prototype and args.weights: - raise ValueError("The weights parameter works only in prototype mode. 
Please pass the --prototype argument.") if args.output_dir: utils.mkdir(args.output_dir) @@ -56,10 +46,7 @@ def main(args): print("Creating model", args.model) # when training quantized models, we always start from a pre-trained fp32 reference model - if not args.prototype: - model = torchvision.models.quantization.__dict__[args.model](pretrained=True, quantize=args.test_only) - else: - model = prototype.models.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only) + model = torchvision.models.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only) model.to(device) if not (args.test_only or args.post_training_quantize): @@ -264,14 +251,6 @@ def get_args_parser(add_help=True): "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)" ) parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)") - - # Prototype models only - parser.add_argument( - "--prototype", - dest="prototype", - help="Use prototype model builders instead those from main area", - action="store_true", - ) parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") return parser diff --git a/references/classification/utils.py b/references/classification/utils.py index 7f573415c4c..32658a7c137 100644 --- a/references/classification/utils.py +++ b/references/classification/utils.py @@ -330,22 +330,22 @@ def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=T from torchvision import models as M # Classification - model = M.mobilenet_v3_large(pretrained=False) + model = M.mobilenet_v3_large(weights=None) print(store_model_weights(model, './class.pth')) # Quantized Classification - model = M.quantization.mobilenet_v3_large(pretrained=False, quantize=False) + model = M.quantization.mobilenet_v3_large(weights=None, quantize=False) model.fuse_model(is_qat=True) model.qconfig = torch.ao.quantization.get_default_qat_qconfig('qnnpack') _ = torch.ao.quantization.prepare_qat(model, inplace=True) print(store_model_weights(model, './qat.pth')) # Object Detection - model = M.detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, pretrained_backbone=False) + model = M.detection.fasterrcnn_mobilenet_v3_large_fpn(weights=None, weights_backbone=None) print(store_model_weights(model, './obj.pth')) # Segmentation - model = M.segmentation.deeplabv3_mobilenet_v3_large(pretrained=False, pretrained_backbone=False, aux_loss=True) + model = M.segmentation.deeplabv3_mobilenet_v3_large(weights=None, weights_backbone=None, aux_loss=True) print(store_model_weights(model, './segm.pth', strict=False)) Args: diff --git a/references/detection/README.md b/references/detection/README.md index 3695644138b..aec7c10e1b5 100644 --- a/references/detection/README.md +++ b/references/detection/README.md @@ -24,35 +24,35 @@ Except otherwise noted, all models have been trained on 8x V100 GPUs. 
``` torchrun --nproc_per_node=8 train.py\ --dataset coco --model fasterrcnn_resnet50_fpn --epochs 26\ - --lr-steps 16 22 --aspect-ratio-group-factor 3 + --lr-steps 16 22 --aspect-ratio-group-factor 3 --weights-backbone ResNet50_Weights.IMAGENET1K_V1 ``` ### Faster R-CNN MobileNetV3-Large FPN ``` torchrun --nproc_per_node=8 train.py\ --dataset coco --model fasterrcnn_mobilenet_v3_large_fpn --epochs 26\ - --lr-steps 16 22 --aspect-ratio-group-factor 3 + --lr-steps 16 22 --aspect-ratio-group-factor 3 --weights-backbone MobileNet_V3_Large_Weights.IMAGENET1K_V1 ``` ### Faster R-CNN MobileNetV3-Large 320 FPN ``` torchrun --nproc_per_node=8 train.py\ --dataset coco --model fasterrcnn_mobilenet_v3_large_320_fpn --epochs 26\ - --lr-steps 16 22 --aspect-ratio-group-factor 3 + --lr-steps 16 22 --aspect-ratio-group-factor 3 --weights-backbone MobileNet_V3_Large_Weights.IMAGENET1K_V1 ``` ### FCOS ResNet-50 FPN ``` torchrun --nproc_per_node=8 train.py\ --dataset coco --model fcos_resnet50_fpn --epochs 26\ - --lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01 --amp + --lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01 --amp --weights-backbone ResNet50_Weights.IMAGENET1K_V1 ``` ### RetinaNet ``` torchrun --nproc_per_node=8 train.py\ --dataset coco --model retinanet_resnet50_fpn --epochs 26\ - --lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01 + --lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01 --weights-backbone ResNet50_Weights.IMAGENET1K_V1 ``` ### SSD300 VGG16 @@ -60,7 +60,7 @@ torchrun --nproc_per_node=8 train.py\ torchrun --nproc_per_node=8 train.py\ --dataset coco --model ssd300_vgg16 --epochs 120\ --lr-steps 80 110 --aspect-ratio-group-factor 3 --lr 0.002 --batch-size 4\ - --weight-decay 0.0005 --data-augmentation ssd + --weight-decay 0.0005 --data-augmentation ssd --weights-backbone VGG16_Weights.IMAGENET1K_FEATURES ``` ### SSDlite320 MobileNetV3-Large @@ -68,7 +68,7 @@ torchrun --nproc_per_node=8 train.py\ torchrun --nproc_per_node=8 train.py\ --dataset coco --model ssdlite320_mobilenet_v3_large --epochs 660\ --aspect-ratio-group-factor 3 --lr-scheduler cosineannealinglr --lr 0.15 --batch-size 24\ - --weight-decay 0.00004 --data-augmentation ssdlite + --weight-decay 0.00004 --data-augmentation ssdlite --weights-backbone MobileNet_V3_Large_Weights.IMAGENET1K_V1 ``` @@ -76,7 +76,7 @@ torchrun --nproc_per_node=8 train.py\ ``` torchrun --nproc_per_node=8 train.py\ --dataset coco --model maskrcnn_resnet50_fpn --epochs 26\ - --lr-steps 16 22 --aspect-ratio-group-factor 3 + --lr-steps 16 22 --aspect-ratio-group-factor 3 --weights-backbone ResNet50_Weights.IMAGENET1K_V1 ``` @@ -84,5 +84,5 @@ torchrun --nproc_per_node=8 train.py\ ``` torchrun --nproc_per_node=8 train.py\ --dataset coco_kp --model keypointrcnn_resnet50_fpn --epochs 46\ - --lr-steps 36 43 --aspect-ratio-group-factor 3 + --lr-steps 36 43 --aspect-ratio-group-factor 3 --weights-backbone ResNet50_Weights.IMAGENET1K_V1 ``` diff --git a/references/detection/train.py b/references/detection/train.py index 3909e6413d0..b6634061503 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -33,12 +33,6 @@ from group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups -try: - from torchvision import prototype -except ImportError: - prototype = None - - def get_dataset(name, image_set, transform, data_path): paths = {"coco": (data_path, get_coco, 91), "coco_kp": (data_path, get_coco_kp, 2)} p, ds_fn, num_classes = paths[name] @@ -49,15 +43,13 @@ def get_dataset(name, image_set, 
transform, data_path): def get_transform(train, args): if train: - return presets.DetectionPresetTrain(args.data_augmentation) - elif not args.prototype: - return presets.DetectionPresetEval() + return presets.DetectionPresetTrain(data_augmentation=args.data_augmentation) + elif args.weights and args.test_only: + weights = torchvision.models.get_weight(args.weights) + trans = weights.transforms() + return lambda img, target: (trans(img), target) else: - if args.weights: - weights = prototype.models.get_weight(args.weights) - return weights.transforms() - else: - return prototype.transforms.ObjectDetectionEval() + return presets.DetectionPresetEval() def get_args_parser(add_help=True): @@ -132,25 +124,12 @@ def get_args_parser(add_help=True): help="Only test the model", action="store_true", ) - parser.add_argument( - "--pretrained", - dest="pretrained", - help="Use pre-trained models from the modelzoo", - action="store_true", - ) # distributed training parameters parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes") parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training") - - # Prototype models only - parser.add_argument( - "--prototype", - dest="prototype", - help="Use prototype model builders instead those from main area", - action="store_true", - ) parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") + parser.add_argument("--weights-backbone", default=None, type=str, help="the backbone weights enum name to load") # Mixed precision training parameters parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training") @@ -159,10 +138,6 @@ def get_args_parser(add_help=True): def main(args): - if args.prototype and prototype is None: - raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.") - if not args.prototype and args.weights: - raise ValueError("The weights parameter works only in prototype mode. 
Please pass the --prototype argument.") if args.output_dir: utils.mkdir(args.output_dir) @@ -204,12 +179,9 @@ def main(args): if "rcnn" in args.model: if args.rpn_score_thresh is not None: kwargs["rpn_score_thresh"] = args.rpn_score_thresh - if not args.prototype: - model = torchvision.models.detection.__dict__[args.model]( - pretrained=args.pretrained, num_classes=num_classes, **kwargs - ) - else: - model = prototype.models.detection.__dict__[args.model](weights=args.weights, num_classes=num_classes, **kwargs) + model = torchvision.models.detection.__dict__[args.model]( + weights=args.weights, weights_backbone=args.weights_backbone, num_classes=num_classes, **kwargs + ) model.to(device) if args.distributed and args.sync_bn: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) diff --git a/references/optical_flow/README.md b/references/optical_flow/README.md index a7620ce4be6..a7ac0223739 100644 --- a/references/optical_flow/README.md +++ b/references/optical_flow/README.md @@ -51,7 +51,7 @@ torchrun --nproc_per_node 8 --nnodes 1 train.py \ ### Evaluation ``` -torchrun --nproc_per_node 1 --nnodes 1 train.py --val-dataset sintel --batch-size 1 --dataset-root $dataset_root --model raft_large --pretrained +torchrun --nproc_per_node 1 --nnodes 1 train.py --val-dataset sintel --batch-size 1 --dataset-root $dataset_root --model raft_large --weights Raft_Large_Weights.C_T_SKHT_V2 ``` This should give an epe of about 1.3822 on the clean pass and 2.7161 on the @@ -67,6 +67,6 @@ Sintel val final epe: 2.7161 1px: 0.8528 3px: 0.9204 5px: 0.9392 per_image_epe: You can also evaluate on Kitti train: ``` -torchrun --nproc_per_node 1 --nnodes 1 train.py --val-dataset kitti --batch-size 1 --dataset-root $dataset_root --model raft_large --pretrained +torchrun --nproc_per_node 1 --nnodes 1 train.py --val-dataset kitti --batch-size 1 --dataset-root $dataset_root --model raft_large --weights Raft_Large_Weights.C_T_SKHT_V2 Kitti val epe: 4.7968 1px: 0.6388 3px: 0.8197 5px: 0.8661 per_image_epe: 4.5118 f1: 16.0679 ``` diff --git a/references/optical_flow/presets.py b/references/optical_flow/presets.py index 43ff4a24f3b..32d9542e692 100644 --- a/references/optical_flow/presets.py +++ b/references/optical_flow/presets.py @@ -22,6 +22,7 @@ def forward(self, img1, img2, flow, valid): class OpticalFlowPresetTrain(torch.nn.Module): def __init__( self, + *, # RandomResizeAndCrop params crop_size, min_scale=-0.2, diff --git a/references/optical_flow/train.py b/references/optical_flow/train.py index 83952242eb9..5070cb554d4 100644 --- a/references/optical_flow/train.py +++ b/references/optical_flow/train.py @@ -9,11 +9,6 @@ from presets import OpticalFlowPresetTrain, OpticalFlowPresetEval from torchvision.datasets import KittiFlow, FlyingChairs, FlyingThings3D, Sintel, HD1K -try: - from torchvision import prototype -except ImportError: - prototype = None - def get_train_dataset(stage, dataset_root): if stage == "chairs": @@ -138,12 +133,18 @@ def inner_loop(blob): def evaluate(model, args): val_datasets = args.val_dataset or [] - if args.prototype: - if args.weights: - weights = prototype.models.get_weight(args.weights) - preprocessing = weights.transforms() - else: - preprocessing = prototype.transforms.OpticalFlowEval() + if args.weights and args.test_only: + weights = torchvision.models.get_weight(args.weights) + trans = weights.transforms() + + def preprocessing(img1, img2, flow, valid_flow_mask): + img1, img2 = trans(img1, img2) + if flow is not None and not isinstance(flow, torch.Tensor): + flow = 
torch.from_numpy(flow) + if valid_flow_mask is not None and not isinstance(valid_flow_mask, torch.Tensor): + valid_flow_mask = torch.from_numpy(valid_flow_mask) + return img1, img2, flow, valid_flow_mask + else: preprocessing = OpticalFlowPresetEval() @@ -201,20 +202,14 @@ def train_one_epoch(model, optimizer, scheduler, train_loader, logger, args): def main(args): - if args.prototype and prototype is None: - raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.") - if not args.prototype and args.weights: - raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.") utils.setup_ddp(args) + args.test_only = args.train_dataset is None if args.distributed and args.device == "cpu": raise ValueError("The device must be cuda if we want to run in distributed mode using torchrun") device = torch.device(args.device) - if args.prototype: - model = prototype.models.optical_flow.__dict__[args.model](weights=args.weights) - else: - model = torchvision.models.optical_flow.__dict__[args.model](pretrained=args.pretrained) + model = torchvision.models.optical_flow.__dict__[args.model](weights=args.weights) if args.distributed: model = model.to(args.local_rank) @@ -228,7 +223,7 @@ def main(args): checkpoint = torch.load(args.resume, map_location="cpu") model_without_ddp.load_state_dict(checkpoint["model"]) - if args.train_dataset is None: + if args.test_only: # Set deterministic CUDNN algorithms, since they can affect epe a fair bit. torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True @@ -356,8 +351,7 @@ def get_args_parser(add_help=True): parser.add_argument( "--model", type=str, default="raft_large", help="The name of the model to use - either raft_large or raft_small" ) - # TODO: resume, pretrained, and weights should be in an exclusive arg group - parser.add_argument("--pretrained", action="store_true", help="Whether to use pretrained weights") + # TODO: resume and weights should be in an exclusive arg group parser.add_argument( "--num_flow_updates", @@ -376,13 +370,6 @@ def get_args_parser(add_help=True): required=True, ) - # Prototype models only - parser.add_argument( - "--prototype", - dest="prototype", - help="Use prototype model builders instead those from main area", - action="store_true", - ) parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load.") parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu, Default: cuda)") diff --git a/references/segmentation/README.md b/references/segmentation/README.md index e9b5391215a..2c7391c8380 100644 --- a/references/segmentation/README.md +++ b/references/segmentation/README.md @@ -14,30 +14,30 @@ You must modify the following flags: ## fcn_resnet50 ``` -torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model fcn_resnet50 --aux-loss +torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model fcn_resnet50 --aux-loss --weights-backbone ResNet50_Weights.IMAGENET1K_V1 ``` ## fcn_resnet101 ``` -torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model fcn_resnet101 --aux-loss +torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model fcn_resnet101 --aux-loss --weights-backbone ResNet101_Weights.IMAGENET1K_V1 ``` ## deeplabv3_resnet50 ``` -torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model deeplabv3_resnet50 --aux-loss +torchrun --nproc_per_node=8 train.py --lr 0.02 
--dataset coco -b 4 --model deeplabv3_resnet50 --aux-loss --weights-backbone ResNet50_Weights.IMAGENET1K_V1 ``` ## deeplabv3_resnet101 ``` -torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model deeplabv3_resnet101 --aux-loss +torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model deeplabv3_resnet101 --aux-loss --weights-backbone ResNet101_Weights.IMAGENET1K_V1 ``` ## deeplabv3_mobilenet_v3_large ``` -torchrun --nproc_per_node=8 train.py --dataset coco -b 4 --model deeplabv3_mobilenet_v3_large --aux-loss --wd 0.000001 +torchrun --nproc_per_node=8 train.py --dataset coco -b 4 --model deeplabv3_mobilenet_v3_large --aux-loss --wd 0.000001 --weights-backbone MobileNet_V3_Large_Weights.IMAGENET1K_V1 ``` ## lraspp_mobilenet_v3_large ``` -torchrun --nproc_per_node=8 train.py --dataset coco -b 4 --model lraspp_mobilenet_v3_large --wd 0.000001 +torchrun --nproc_per_node=8 train.py --dataset coco -b 4 --model lraspp_mobilenet_v3_large --wd 0.000001 --weights-backbone MobileNet_V3_Large_Weights.IMAGENET1K_V1 ``` diff --git a/references/segmentation/presets.py b/references/segmentation/presets.py index 8cada98ac95..ed02ae660e4 100644 --- a/references/segmentation/presets.py +++ b/references/segmentation/presets.py @@ -3,7 +3,7 @@ class SegmentationPresetTrain: - def __init__(self, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)): + def __init__(self, *, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)): min_size = int(0.5 * base_size) max_size = int(2.0 * base_size) @@ -25,7 +25,7 @@ def __call__(self, img, target): class SegmentationPresetEval: - def __init__(self, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)): + def __init__(self, *, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)): self.transforms = T.Compose( [ T.RandomResize(base_size, base_size), diff --git a/references/segmentation/train.py b/references/segmentation/train.py index 5dc03945bd7..e8570ab7f69 100644 --- a/references/segmentation/train.py +++ b/references/segmentation/train.py @@ -9,12 +9,7 @@ import utils from coco_utils import get_coco from torch import nn - - -try: - from torchvision import prototype -except ImportError: - prototype = None +from torchvision.transforms import functional as F, InterpolationMode def get_dataset(dir_path, name, image_set, transform): @@ -35,14 +30,19 @@ def sbd(*args, **kwargs): def get_transform(train, args): if train: return presets.SegmentationPresetTrain(base_size=520, crop_size=480) - elif not args.prototype: - return presets.SegmentationPresetEval(base_size=520) + elif args.weights and args.test_only: + weights = torchvision.models.get_weight(args.weights) + trans = weights.transforms() + + def preprocessing(img, target): + img = trans(img) + size = F.get_dimensions(img)[1:] + target = F.resize(target, size, interpolation=InterpolationMode.NEAREST) + return img, F.pil_to_tensor(target) + + return preprocessing else: - if args.weights: - weights = prototype.models.get_weight(args.weights) - return weights.transforms() - else: - return prototype.transforms.SemanticSegmentationEval(resize_size=520) + return presets.SegmentationPresetEval(base_size=520) def criterion(inputs, target): @@ -100,10 +100,6 @@ def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, devi def main(args): - if args.prototype and prototype is None: - raise ImportError("The prototype module couldn't be found. 
Please install the latest torchvision nightly.") - if not args.prototype and args.weights: - raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.") if args.output_dir: utils.mkdir(args.output_dir) @@ -135,16 +131,9 @@ def main(args): dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn ) - if not args.prototype: - model = torchvision.models.segmentation.__dict__[args.model]( - pretrained=args.pretrained, - num_classes=num_classes, - aux_loss=args.aux_loss, - ) - else: - model = prototype.models.segmentation.__dict__[args.model]( - weights=args.weights, num_classes=num_classes, aux_loss=args.aux_loss - ) + model = torchvision.models.segmentation.__dict__[args.model]( + weights=args.weights, weights_backbone=args.weights_backbone, num_classes=num_classes, aux_loss=args.aux_loss + ) model.to(device) if args.distributed: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) @@ -272,24 +261,12 @@ def get_args_parser(add_help=True): help="Only test the model", action="store_true", ) - parser.add_argument( - "--pretrained", - dest="pretrained", - help="Use pre-trained models from the modelzoo", - action="store_true", - ) # distributed training parameters parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes") parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training") - # Prototype models only - parser.add_argument( - "--prototype", - dest="prototype", - help="Use prototype model builders instead those from main area", - action="store_true", - ) parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") + parser.add_argument("--weights-backbone", default=None, type=str, help="the backbone weights enum name to load") # Mixed precision training parameters parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training") diff --git a/references/video_classification/presets.py b/references/video_classification/presets.py index 04039c9a4f1..c12d00a022b 100644 --- a/references/video_classification/presets.py +++ b/references/video_classification/presets.py @@ -6,8 +6,9 @@ class VideoClassificationPresetTrain: def __init__( self, - resize_size, + *, crop_size, + resize_size, mean=(0.43216, 0.394666, 0.37645), std=(0.22803, 0.22145, 0.216989), hflip_prob=0.5, @@ -27,7 +28,7 @@ def __call__(self, x): class VideoClassificationPresetEval: - def __init__(self, resize_size, crop_size, mean=(0.43216, 0.394666, 0.37645), std=(0.22803, 0.22145, 0.216989)): + def __init__(self, *, crop_size, resize_size, mean=(0.43216, 0.394666, 0.37645), std=(0.22803, 0.22145, 0.216989)): self.transforms = transforms.Compose( [ ConvertBHWCtoBCHW(), diff --git a/references/video_classification/train.py b/references/video_classification/train.py index d36785ddf96..918a012282e 100644 --- a/references/video_classification/train.py +++ b/references/video_classification/train.py @@ -12,11 +12,6 @@ from torch.utils.data.dataloader import default_collate from torchvision.datasets.samplers import DistributedSampler, UniformClipSampler, RandomClipSampler -try: - from torchvision import prototype -except ImportError: - prototype = None - def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, scaler=None): model.train() @@ -96,17 +91,11 @@ def collate_fn(batch): def main(args): - if args.prototype and prototype is 
None: - raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.") - if not args.prototype and args.weights: - raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.") if args.output_dir: utils.mkdir(args.output_dir) utils.init_distributed_mode(args) print(args) - print("torch version: ", torch.__version__) - print("torchvision version: ", torchvision.__version__) device = torch.device(args.device) @@ -120,7 +109,7 @@ def main(args): print("Loading training data") st = time.time() cache_path = _get_cache_path(traindir) - transform_train = presets.VideoClassificationPresetTrain((128, 171), (112, 112)) + transform_train = presets.VideoClassificationPresetTrain(crop_size=(112, 112), resize_size=(128, 171)) if args.cache_dataset and os.path.exists(cache_path): print(f"Loading dataset_train from {cache_path}") @@ -150,14 +139,11 @@ def main(args): print("Loading validation data") cache_path = _get_cache_path(valdir) - if not args.prototype: - transform_test = presets.VideoClassificationPresetEval(resize_size=(128, 171), crop_size=(112, 112)) + if args.weights and args.test_only: + weights = torchvision.models.get_weight(args.weights) + transform_test = weights.transforms() else: - if args.weights: - weights = prototype.models.get_weight(args.weights) - transform_test = weights.transforms() - else: - transform_test = prototype.transforms.VideoClassificationEval(crop_size=(112, 112), resize_size=(128, 171)) + transform_test = presets.VideoClassificationPresetEval(crop_size=(112, 112), resize_size=(128, 171)) if args.cache_dataset and os.path.exists(cache_path): print(f"Loading dataset_test from {cache_path}") @@ -208,10 +194,7 @@ def main(args): ) print("Creating model") - if not args.prototype: - model = torchvision.models.video.__dict__[args.model](pretrained=args.pretrained) - else: - model = prototype.models.video.__dict__[args.model](weights=args.weights) + model = torchvision.models.video.__dict__[args.model](weights=args.weights) model.to(device) if args.distributed and args.sync_bn: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) @@ -352,24 +335,11 @@ def parse_args(): help="Only test the model", action="store_true", ) - parser.add_argument( - "--pretrained", - dest="pretrained", - help="Use pre-trained models from the modelzoo", - action="store_true", - ) # distributed training parameters parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes") parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training") - # Prototype models only - parser.add_argument( - "--prototype", - dest="prototype", - help="Use prototype model builders instead those from main area", - action="store_true", - ) parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") # Mixed precision training parameters diff --git a/test/test_backbone_utils.py b/test/test_backbone_utils.py index a3ba427f1de..60d8f8d167d 100644 --- a/test/test_backbone_utils.py +++ b/test/test_backbone_utils.py @@ -13,36 +13,40 @@ def get_available_models(): # TODO add a registration mechanism to torchvision.models - return [k for k, v in models.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"] + return [ + k + for k, v in models.__dict__.items() + if callable(v) and k[0].lower() == k[0] and k[0] != "_" and k != "get_weight" + ] @pytest.mark.parametrize("backbone_name", ("resnet18", 
"resnet50")) def test_resnet_fpn_backbone(backbone_name): x = torch.rand(1, 3, 300, 300, dtype=torch.float32, device="cpu") - model = resnet_fpn_backbone(backbone_name=backbone_name, pretrained=False) + model = resnet_fpn_backbone(backbone_name=backbone_name, weights=None) assert isinstance(model, BackboneWithFPN) y = model(x) assert list(y.keys()) == ["0", "1", "2", "3", "pool"] with pytest.raises(ValueError, match=r"Trainable layers should be in the range"): - resnet_fpn_backbone(backbone_name=backbone_name, pretrained=False, trainable_layers=6) + resnet_fpn_backbone(backbone_name=backbone_name, weights=None, trainable_layers=6) with pytest.raises(ValueError, match=r"Each returned layer should be in the range"): - resnet_fpn_backbone(backbone_name, False, returned_layers=[0, 1, 2, 3]) + resnet_fpn_backbone(backbone_name=backbone_name, weights=None, returned_layers=[0, 1, 2, 3]) with pytest.raises(ValueError, match=r"Each returned layer should be in the range"): - resnet_fpn_backbone(backbone_name, False, returned_layers=[2, 3, 4, 5]) + resnet_fpn_backbone(backbone_name=backbone_name, weights=None, returned_layers=[2, 3, 4, 5]) @pytest.mark.parametrize("backbone_name", ("mobilenet_v2", "mobilenet_v3_large", "mobilenet_v3_small")) def test_mobilenet_backbone(backbone_name): with pytest.raises(ValueError, match=r"Trainable layers should be in the range"): - mobilenet_backbone(backbone_name=backbone_name, pretrained=False, fpn=False, trainable_layers=-1) + mobilenet_backbone(backbone_name=backbone_name, weights=None, fpn=False, trainable_layers=-1) with pytest.raises(ValueError, match=r"Each returned layer should be in the range"): - mobilenet_backbone(backbone_name, False, fpn=True, returned_layers=[-1, 0, 1, 2]) + mobilenet_backbone(backbone_name=backbone_name, weights=None, fpn=True, returned_layers=[-1, 0, 1, 2]) with pytest.raises(ValueError, match=r"Each returned layer should be in the range"): - mobilenet_backbone(backbone_name, False, fpn=True, returned_layers=[3, 4, 5, 6]) - model_fpn = mobilenet_backbone(backbone_name, False, fpn=True) + mobilenet_backbone(backbone_name=backbone_name, weights=None, fpn=True, returned_layers=[3, 4, 5, 6]) + model_fpn = mobilenet_backbone(backbone_name=backbone_name, weights=None, fpn=True) assert isinstance(model_fpn, BackboneWithFPN) - model = mobilenet_backbone(backbone_name, False, fpn=False) + model = mobilenet_backbone(backbone_name=backbone_name, weights=None, fpn=False) assert isinstance(model, torch.nn.Sequential) @@ -96,7 +100,7 @@ def forward(self, x): class TestFxFeatureExtraction: inp = torch.rand(1, 3, 224, 224, dtype=torch.float32, device="cpu") - model_defaults = {"num_classes": 1, "pretrained": False} + model_defaults = {"num_classes": 1} leaf_modules = [] def _create_feature_extractor(self, *args, **kwargs): diff --git a/test/test_cpp_models.py b/test/test_cpp_models.py index f7cce7b6c43..d8d0836d499 100644 --- a/test/test_cpp_models.py +++ b/test/test_cpp_models.py @@ -53,50 +53,49 @@ def read_image2(): "see https://github.com/pytorch/vision/issues/1191", ) class Tester(unittest.TestCase): - pretrained = False image = read_image1() def test_alexnet(self): - process_model(models.alexnet(self.pretrained), self.image, _C_tests.forward_alexnet, "Alexnet") + process_model(models.alexnet(), self.image, _C_tests.forward_alexnet, "Alexnet") def test_vgg11(self): - process_model(models.vgg11(self.pretrained), self.image, _C_tests.forward_vgg11, "VGG11") + process_model(models.vgg11(), self.image, _C_tests.forward_vgg11, 
"VGG11") def test_vgg13(self): - process_model(models.vgg13(self.pretrained), self.image, _C_tests.forward_vgg13, "VGG13") + process_model(models.vgg13(), self.image, _C_tests.forward_vgg13, "VGG13") def test_vgg16(self): - process_model(models.vgg16(self.pretrained), self.image, _C_tests.forward_vgg16, "VGG16") + process_model(models.vgg16(), self.image, _C_tests.forward_vgg16, "VGG16") def test_vgg19(self): - process_model(models.vgg19(self.pretrained), self.image, _C_tests.forward_vgg19, "VGG19") + process_model(models.vgg19(), self.image, _C_tests.forward_vgg19, "VGG19") def test_vgg11_bn(self): - process_model(models.vgg11_bn(self.pretrained), self.image, _C_tests.forward_vgg11bn, "VGG11BN") + process_model(models.vgg11_bn(), self.image, _C_tests.forward_vgg11bn, "VGG11BN") def test_vgg13_bn(self): - process_model(models.vgg13_bn(self.pretrained), self.image, _C_tests.forward_vgg13bn, "VGG13BN") + process_model(models.vgg13_bn(), self.image, _C_tests.forward_vgg13bn, "VGG13BN") def test_vgg16_bn(self): - process_model(models.vgg16_bn(self.pretrained), self.image, _C_tests.forward_vgg16bn, "VGG16BN") + process_model(models.vgg16_bn(), self.image, _C_tests.forward_vgg16bn, "VGG16BN") def test_vgg19_bn(self): - process_model(models.vgg19_bn(self.pretrained), self.image, _C_tests.forward_vgg19bn, "VGG19BN") + process_model(models.vgg19_bn(), self.image, _C_tests.forward_vgg19bn, "VGG19BN") def test_resnet18(self): - process_model(models.resnet18(self.pretrained), self.image, _C_tests.forward_resnet18, "Resnet18") + process_model(models.resnet18(), self.image, _C_tests.forward_resnet18, "Resnet18") def test_resnet34(self): - process_model(models.resnet34(self.pretrained), self.image, _C_tests.forward_resnet34, "Resnet34") + process_model(models.resnet34(), self.image, _C_tests.forward_resnet34, "Resnet34") def test_resnet50(self): - process_model(models.resnet50(self.pretrained), self.image, _C_tests.forward_resnet50, "Resnet50") + process_model(models.resnet50(), self.image, _C_tests.forward_resnet50, "Resnet50") def test_resnet101(self): - process_model(models.resnet101(self.pretrained), self.image, _C_tests.forward_resnet101, "Resnet101") + process_model(models.resnet101(), self.image, _C_tests.forward_resnet101, "Resnet101") def test_resnet152(self): - process_model(models.resnet152(self.pretrained), self.image, _C_tests.forward_resnet152, "Resnet152") + process_model(models.resnet152(), self.image, _C_tests.forward_resnet152, "Resnet152") def test_resnext50_32x4d(self): process_model(models.resnext50_32x4d(), self.image, _C_tests.forward_resnext50_32x4d, "ResNext50_32x4d") @@ -111,48 +110,44 @@ def test_wide_resnet101_2(self): process_model(models.wide_resnet101_2(), self.image, _C_tests.forward_wide_resnet101_2, "WideResNet101_2") def test_squeezenet1_0(self): - process_model( - models.squeezenet1_0(self.pretrained), self.image, _C_tests.forward_squeezenet1_0, "Squeezenet1.0" - ) + process_model(models.squeezenet1_0(), self.image, _C_tests.forward_squeezenet1_0, "Squeezenet1.0") def test_squeezenet1_1(self): - process_model( - models.squeezenet1_1(self.pretrained), self.image, _C_tests.forward_squeezenet1_1, "Squeezenet1.1" - ) + process_model(models.squeezenet1_1(), self.image, _C_tests.forward_squeezenet1_1, "Squeezenet1.1") def test_densenet121(self): - process_model(models.densenet121(self.pretrained), self.image, _C_tests.forward_densenet121, "Densenet121") + process_model(models.densenet121(), self.image, _C_tests.forward_densenet121, "Densenet121") def test_densenet169(self): 
- process_model(models.densenet169(self.pretrained), self.image, _C_tests.forward_densenet169, "Densenet169") + process_model(models.densenet169(), self.image, _C_tests.forward_densenet169, "Densenet169") def test_densenet201(self): - process_model(models.densenet201(self.pretrained), self.image, _C_tests.forward_densenet201, "Densenet201") + process_model(models.densenet201(), self.image, _C_tests.forward_densenet201, "Densenet201") def test_densenet161(self): - process_model(models.densenet161(self.pretrained), self.image, _C_tests.forward_densenet161, "Densenet161") + process_model(models.densenet161(), self.image, _C_tests.forward_densenet161, "Densenet161") def test_mobilenet_v2(self): - process_model(models.mobilenet_v2(self.pretrained), self.image, _C_tests.forward_mobilenetv2, "MobileNet") + process_model(models.mobilenet_v2(), self.image, _C_tests.forward_mobilenetv2, "MobileNet") def test_googlenet(self): - process_model(models.googlenet(self.pretrained), self.image, _C_tests.forward_googlenet, "GoogLeNet") + process_model(models.googlenet(), self.image, _C_tests.forward_googlenet, "GoogLeNet") def test_mnasnet0_5(self): - process_model(models.mnasnet0_5(self.pretrained), self.image, _C_tests.forward_mnasnet0_5, "MNASNet0_5") + process_model(models.mnasnet0_5(), self.image, _C_tests.forward_mnasnet0_5, "MNASNet0_5") def test_mnasnet0_75(self): - process_model(models.mnasnet0_75(self.pretrained), self.image, _C_tests.forward_mnasnet0_75, "MNASNet0_75") + process_model(models.mnasnet0_75(), self.image, _C_tests.forward_mnasnet0_75, "MNASNet0_75") def test_mnasnet1_0(self): - process_model(models.mnasnet1_0(self.pretrained), self.image, _C_tests.forward_mnasnet1_0, "MNASNet1_0") + process_model(models.mnasnet1_0(), self.image, _C_tests.forward_mnasnet1_0, "MNASNet1_0") def test_mnasnet1_3(self): - process_model(models.mnasnet1_3(self.pretrained), self.image, _C_tests.forward_mnasnet1_3, "MNASNet1_3") + process_model(models.mnasnet1_3(), self.image, _C_tests.forward_mnasnet1_3, "MNASNet1_3") def test_inception_v3(self): self.image = read_image2() - process_model(models.inception_v3(self.pretrained), self.image, _C_tests.forward_inceptionv3, "Inceptionv3") + process_model(models.inception_v3(), self.image, _C_tests.forward_inceptionv3, "Inceptionv3") if __name__ == "__main__": diff --git a/test/test_prototype_models.py b/test/test_extended_models.py similarity index 70% rename from test/test_prototype_models.py rename to test/test_extended_models.py index 6c7234e2ef0..a07b501e15b 100644 --- a/test/test_prototype_models.py +++ b/test/test_extended_models.py @@ -4,21 +4,15 @@ import pytest import test_models as TM import torch -from common_utils import cpu_and_gpu, needs_cuda -from torchvision.prototype import models -from torchvision.prototype.models._api import WeightsEnum, Weights -from torchvision.prototype.models._utils import handle_legacy_interface - -run_if_test_with_prototype = pytest.mark.skipif( - os.getenv("PYTORCH_TEST_WITH_PROTOTYPE") != "1", - reason="Prototype tests are disabled by default. 
Set PYTORCH_TEST_WITH_PROTOTYPE=1 to run them.", -) +from torchvision import models +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._utils import handle_legacy_interface -def _get_original_model(model_fn): - original_module_name = model_fn.__module__.replace(".prototype", "") - module = importlib.import_module(original_module_name) - return module.__dict__[model_fn.__name__] +run_if_test_with_extended = pytest.mark.skipif( + os.getenv("PYTORCH_TEST_WITH_EXTENDED", "0") != "1", + reason="Extended tests are disabled by default. Set PYTORCH_TEST_WITH_EXTENDED=1 to run them.", +) def _get_parent_module(model_fn): @@ -40,17 +34,6 @@ def _get_model_weights(model_fn): return None -def _build_model(fn, **kwargs): - try: - model = fn(**kwargs) - except ValueError as e: - msg = str(e) - if "No checkpoint is available" in msg: - pytest.skip(msg) - raise e - return model.eval() - - @pytest.mark.parametrize( "name, weight", [ @@ -95,7 +78,7 @@ def test_naming_conventions(model_fn): + TM.get_models_from_module(models.video) + TM.get_models_from_module(models.optical_flow), ) -@run_if_test_with_prototype +@run_if_test_with_extended def test_schema_meta_validation(model_fn): classification_fields = ["size", "categories", "acc@1", "acc@5", "min_size"] defaults = { @@ -142,48 +125,6 @@ def test_schema_meta_validation(model_fn): assert not bad_names -@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models)) -@pytest.mark.parametrize("dev", cpu_and_gpu()) -@run_if_test_with_prototype -def test_classification_model(model_fn, dev): - TM.test_classification_model(model_fn, dev) - - -@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.detection)) -@pytest.mark.parametrize("dev", cpu_and_gpu()) -@run_if_test_with_prototype -def test_detection_model(model_fn, dev): - TM.test_detection_model(model_fn, dev) - - -@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.quantization)) -@run_if_test_with_prototype -def test_quantized_classification_model(model_fn): - TM.test_quantized_classification_model(model_fn) - - -@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.segmentation)) -@pytest.mark.parametrize("dev", cpu_and_gpu()) -@run_if_test_with_prototype -def test_segmentation_model(model_fn, dev): - TM.test_segmentation_model(model_fn, dev) - - -@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.video)) -@pytest.mark.parametrize("dev", cpu_and_gpu()) -@run_if_test_with_prototype -def test_video_model(model_fn, dev): - TM.test_video_model(model_fn, dev) - - -@needs_cuda -@pytest.mark.parametrize("model_builder", TM.get_models_from_module(models.optical_flow)) -@pytest.mark.parametrize("scripted", (False, True)) -@run_if_test_with_prototype -def test_raft(model_builder, scripted): - TM.test_raft(model_builder, scripted) - - @pytest.mark.parametrize( "model_fn", TM.get_models_from_module(models) @@ -193,9 +134,13 @@ def test_raft(model_builder, scripted): + TM.get_models_from_module(models.video) + TM.get_models_from_module(models.optical_flow), ) -@pytest.mark.parametrize("dev", cpu_and_gpu()) -@run_if_test_with_prototype -def test_old_vs_new_factory(model_fn, dev): +@run_if_test_with_extended +def test_transforms_jit(model_fn): + model_name = model_fn.__name__ + weights_enum = _get_model_weights(model_fn) + if len(weights_enum) == 0: + pytest.skip(f"Model '{model_name}' doesn't have any pre-trained weights.") + defaults = { "models": { "input_shape": (1, 3, 224, 224), @@ -205,43 +150,36 @@ def 
test_old_vs_new_factory(model_fn, dev): }, "quantization": { "input_shape": (1, 3, 224, 224), - "quantize": True, }, "segmentation": { "input_shape": (1, 3, 520, 520), }, "video": { - "input_shape": (1, 3, 4, 112, 112), + "input_shape": (1, 4, 112, 112, 3), }, "optical_flow": { "input_shape": (1, 3, 128, 128), }, } - model_name = model_fn.__name__ module_name = model_fn.__module__.split(".")[-2] - kwargs = {"pretrained": True, **defaults[module_name], **TM._model_params.get(model_name, {})} - input_shape = kwargs.pop("input_shape") - kwargs.pop("num_classes", None) # ignore this as it's an incompatible speed optimization for pre-trained models - x = torch.rand(input_shape).to(device=dev) - if module_name == "detection": - x = [x] + kwargs = {**defaults[module_name], **TM._model_params.get(model_name, {})} + input_shape = kwargs.pop("input_shape") + x = torch.rand(input_shape) if module_name == "optical_flow": - args = [x, x] # RAFT model requires img1, img2 as input + args = (x, x) else: - args = [x] - - # compare with new model builder parameterized in the old fashion way - try: - model_old = _build_model(_get_original_model(model_fn), **kwargs).to(device=dev) - model_new = _build_model(model_fn, **kwargs).to(device=dev) - except ModuleNotFoundError: - pytest.skip(f"Model '{model_name}' not available in both modules.") - torch.testing.assert_close(model_new(*args), model_old(*args), rtol=0.0, atol=0.0, check_dtype=False) + args = (x,) + problematic_weights = [] + for w in weights_enum: + transforms = w.transforms() + try: + TM._check_jit_scriptable(transforms, args) + except Exception: + problematic_weights.append(w) -def test_smoke(): - import torchvision.prototype.models # noqa: F401 + assert not problematic_weights # With this filter, every unexpected warning will be turned into an error diff --git a/test/test_hub.py b/test/test_hub.py index 5c791bf9d7a..d88c6fa2cd2 100644 --- a/test/test_hub.py +++ b/test/test_hub.py @@ -26,13 +26,13 @@ class TestHub: # Python cache as we run all hub tests in the same python process. 
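The hub tests below exercise the new interface by passing the weights by name; `"DEFAULT"` resolves to the most up-to-date entry of the model's weights enum. A minimal standalone sketch of the same call, assuming network access to fetch both the hub repository and the checkpoint:

```
import torch

# "DEFAULT" maps to the newest weights entry for the requested model.
model = torch.hub.load("pytorch/vision", "resnet18", weights="DEFAULT", progress=False)
model.eval()
```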
def test_load_from_github(self): - hub_model = hub.load("pytorch/vision", "resnet18", pretrained=True, progress=False) + hub_model = hub.load("pytorch/vision", "resnet18", weights="DEFAULT", progress=False) assert sum_of_model_parameters(hub_model).item() == pytest.approx(SUM_OF_PRETRAINED_RESNET18_PARAMS) def test_set_dir(self): temp_dir = tempfile.gettempdir() hub.set_dir(temp_dir) - hub_model = hub.load("pytorch/vision", "resnet18", pretrained=True, progress=False) + hub_model = hub.load("pytorch/vision", "resnet18", weights="DEFAULT", progress=False) assert sum_of_model_parameters(hub_model).item() == pytest.approx(SUM_OF_PRETRAINED_RESNET18_PARAMS) assert os.path.exists(temp_dir + "/pytorch_vision_master") shutil.rmtree(temp_dir + "/pytorch_vision_master") diff --git a/test/test_models.py b/test/test_models.py index d657475bafb..9a051a61eab 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -133,8 +133,7 @@ def get_export_import_copy(m): if eager_out is None: with torch.no_grad(), freeze_rng_state(): - if unwrapper: - eager_out = nn_module(*args) + eager_out = nn_module(*args) with torch.no_grad(), freeze_rng_state(): script_out = sm(*args) @@ -414,7 +413,6 @@ def get_gn(num_channels): def test_inception_v3_eval(): - # replacement for models.inception_v3(pretrained=True) that does not download weights kwargs = {} kwargs["transform_input"] = True kwargs["aux_logits"] = True @@ -430,7 +428,7 @@ def test_inception_v3_eval(): def test_fasterrcnn_double(): - model = models.detection.fasterrcnn_resnet50_fpn(num_classes=50, pretrained_backbone=False) + model = models.detection.fasterrcnn_resnet50_fpn(num_classes=50, weights=None, weights_backbone=None) model.double() model.eval() input_shape = (3, 300, 300) @@ -446,7 +444,6 @@ def test_fasterrcnn_double(): def test_googlenet_eval(): - # replacement for models.googlenet(pretrained=True) that does not download weights kwargs = {} kwargs["transform_input"] = True kwargs["aux_logits"] = True @@ -470,7 +467,7 @@ def checkOut(out): assert "scores" in out[0] assert "labels" in out[0] - model = models.detection.fasterrcnn_resnet50_fpn(num_classes=50, pretrained_backbone=False) + model = models.detection.fasterrcnn_resnet50_fpn(num_classes=50, weights=None, weights_backbone=None) model.cuda() model.eval() input_shape = (3, 300, 300) @@ -586,7 +583,7 @@ def test_segmentation_model(model_fn, dev): set_rng_seed(0) defaults = { "num_classes": 10, - "pretrained_backbone": False, + "weights_backbone": None, "input_shape": (1, 3, 32, 32), } model_name = model_fn.__name__ @@ -648,7 +645,7 @@ def test_detection_model(model_fn, dev): set_rng_seed(0) defaults = { "num_classes": 50, - "pretrained_backbone": False, + "weights_backbone": None, "input_shape": (3, 300, 300), } model_name = model_fn.__name__ @@ -743,7 +740,7 @@ def compute_mean_std(tensor): @pytest.mark.parametrize("model_fn", get_models_from_module(models.detection)) def test_detection_model_validation(model_fn): set_rng_seed(0) - model = model_fn(num_classes=50, pretrained_backbone=False) + model = model_fn(num_classes=50, weights=None, weights_backbone=None) input_shape = (3, 300, 300) x = [torch.rand(input_shape)] @@ -807,7 +804,6 @@ def test_quantized_classification_model(model_fn): defaults = { "num_classes": 5, "input_shape": (1, 3, 224, 224), - "pretrained": False, "quantize": True, } model_name = model_fn.__name__ @@ -857,7 +853,7 @@ def test_detection_model_trainable_backbone_layers(model_fn, disable_weight_load max_trainable = _model_tests_values[model_name]["max_trainable"] 
n_trainable_params = [] for trainable_layers in range(0, max_trainable + 1): - model = model_fn(pretrained=False, pretrained_backbone=True, trainable_backbone_layers=trainable_layers) + model = model_fn(weights=None, weights_backbone="DEFAULT", trainable_backbone_layers=trainable_layers) n_trainable_params.append(len([p for p in model.parameters() if p.requires_grad])) assert n_trainable_params == _model_tests_values[model_name]["n_trn_params_per_layer"] diff --git a/test/test_models_detection_negative_samples.py b/test/test_models_detection_negative_samples.py index 7d2953f7e64..c4efbd96cf3 100644 --- a/test/test_models_detection_negative_samples.py +++ b/test/test_models_detection_negative_samples.py @@ -100,7 +100,7 @@ def test_assign_targets_to_proposals(self): ) def test_forward_negative_sample_frcnn(self, name): model = torchvision.models.detection.__dict__[name]( - num_classes=2, min_size=100, max_size=100, pretrained_backbone=False + weights=None, weights_backbone=None, num_classes=2, min_size=100, max_size=100 ) images, targets = self._make_empty_sample() @@ -111,7 +111,7 @@ def test_forward_negative_sample_frcnn(self, name): def test_forward_negative_sample_mrcnn(self): model = torchvision.models.detection.maskrcnn_resnet50_fpn( - num_classes=2, min_size=100, max_size=100, pretrained_backbone=False + weights=None, weights_backbone=None, num_classes=2, min_size=100, max_size=100 ) images, targets = self._make_empty_sample(add_masks=True) @@ -123,7 +123,7 @@ def test_forward_negative_sample_mrcnn(self): def test_forward_negative_sample_krcnn(self): model = torchvision.models.detection.keypointrcnn_resnet50_fpn( - num_classes=2, min_size=100, max_size=100, pretrained_backbone=False + weights=None, weights_backbone=None, num_classes=2, min_size=100, max_size=100 ) images, targets = self._make_empty_sample(add_keypoints=True) @@ -135,7 +135,7 @@ def test_forward_negative_sample_krcnn(self): def test_forward_negative_sample_retinanet(self): model = torchvision.models.detection.retinanet_resnet50_fpn( - num_classes=2, min_size=100, max_size=100, pretrained_backbone=False + weights=None, weights_backbone=None, num_classes=2, min_size=100, max_size=100 ) images, targets = self._make_empty_sample() @@ -145,7 +145,7 @@ def test_forward_negative_sample_retinanet(self): def test_forward_negative_sample_fcos(self): model = torchvision.models.detection.fcos_resnet50_fpn( - num_classes=2, min_size=100, max_size=100, pretrained_backbone=False + weights=None, weights_backbone=None, num_classes=2, min_size=100, max_size=100 ) images, targets = self._make_empty_sample() @@ -155,7 +155,7 @@ def test_forward_negative_sample_fcos(self): assert_equal(loss_dict["bbox_ctrness"], torch.tensor(0.0)) def test_forward_negative_sample_ssd(self): - model = torchvision.models.detection.ssd300_vgg16(num_classes=2, pretrained_backbone=False) + model = torchvision.models.detection.ssd300_vgg16(weights=None, weights_backbone=None, num_classes=2) images, targets = self._make_empty_sample() loss_dict = model(images, targets) diff --git a/test/test_models_detection_utils.py b/test/test_models_detection_utils.py index 44abfd51a7f..a160113cbbf 100644 --- a/test/test_models_detection_utils.py +++ b/test/test_models_detection_utils.py @@ -40,7 +40,7 @@ def test_resnet_fpn_backbone_frozen_layers(self, train_layers, exp_froz_params): # be frozen for each trainable_backbone_layers parameter value # i.e all 53 params are frozen if trainable_backbone_layers=0 # ad first 24 params are frozen if trainable_backbone_layers=2 - 
model = backbone_utils.resnet_fpn_backbone("resnet50", pretrained=False, trainable_layers=train_layers) + model = backbone_utils.resnet_fpn_backbone("resnet50", weights=None, trainable_layers=train_layers) # boolean list that is true if the param at that index is frozen is_frozen = [not parameter.requires_grad for _, parameter in model.named_parameters()] # check that expected initial number of layers are frozen @@ -49,18 +49,18 @@ def test_resnet_fpn_backbone_frozen_layers(self, train_layers, exp_froz_params): def test_validate_resnet_inputs_detection(self): # default number of backbone layers to train ret = backbone_utils._validate_trainable_layers( - pretrained=True, trainable_backbone_layers=None, max_value=5, default_value=3 + is_trained=True, trainable_backbone_layers=None, max_value=5, default_value=3 ) assert ret == 3 # can't go beyond 5 with pytest.raises(ValueError, match=r"Trainable backbone layers should be in the range"): ret = backbone_utils._validate_trainable_layers( - pretrained=True, trainable_backbone_layers=6, max_value=5, default_value=3 + is_trained=True, trainable_backbone_layers=6, max_value=5, default_value=3 ) - # if not pretrained, should use all trainable layers and warn + # if not trained, should use all trainable layers and warn with pytest.warns(UserWarning): ret = backbone_utils._validate_trainable_layers( - pretrained=False, trainable_backbone_layers=0, max_value=5, default_value=3 + is_trained=False, trainable_backbone_layers=0, max_value=5, default_value=3 ) assert ret == 5 diff --git a/test/test_onnx.py b/test/test_onnx.py index b725cdf2b90..375d0fd1c6f 100644 --- a/test/test_onnx.py +++ b/test/test_onnx.py @@ -430,7 +430,9 @@ def get_test_images(self) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: def test_faster_rcnn(self): images, test_images = self.get_test_images() dummy_image = [torch.ones(3, 100, 100) * 0.3] - model = models.detection.faster_rcnn.fasterrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300) + model = models.detection.faster_rcnn.fasterrcnn_resnet50_fpn( + weights=models.detection.faster_rcnn.FasterRCNN_ResNet50_FPN_Weights.DEFAULT, min_size=200, max_size=300 + ) model.eval() model(images) # Test exported model on images of different size, or dummy input @@ -486,7 +488,9 @@ def test_paste_mask_in_image(self): def test_mask_rcnn(self): images, test_images = self.get_test_images() dummy_image = [torch.ones(3, 100, 100) * 0.3] - model = models.detection.mask_rcnn.maskrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300) + model = models.detection.mask_rcnn.maskrcnn_resnet50_fpn( + weights=models.detection.mask_rcnn.MaskRCNN_ResNet50_FPN_Weights.DEFAULT, min_size=200, max_size=300 + ) model.eval() model(images) # Test exported model on images of different size, or dummy input @@ -548,7 +552,9 @@ def test_heatmaps_to_keypoints(self): def test_keypoint_rcnn(self): images, test_images = self.get_test_images() dummy_images = [torch.ones(3, 100, 100) * 0.3] - model = models.detection.keypoint_rcnn.keypointrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300) + model = models.detection.keypoint_rcnn.keypointrcnn_resnet50_fpn( + weights=models.detection.keypoint_rcnn.KeypointRCNN_ResNet50_FPN_Weights.DEFAULT, min_size=200, max_size=300 + ) model.eval() model(images) self.run_model( @@ -570,7 +576,7 @@ def test_keypoint_rcnn(self): ) def test_shufflenet_v2_dynamic_axes(self): - model = models.shufflenet_v2_x0_5(pretrained=True) + model = models.shufflenet_v2_x0_5(weights=models.ShuffleNet_V2_X0_5_Weights.DEFAULT) 
dummy_input = torch.randn(1, 3, 224, 224, requires_grad=True) test_inputs = torch.cat([dummy_input, dummy_input, dummy_input], 0) diff --git a/test/tracing/frcnn/trace_model.py b/test/tracing/frcnn/trace_model.py index 8cc1d344936..b5ec50bdab1 100644 --- a/test/tracing/frcnn/trace_model.py +++ b/test/tracing/frcnn/trace_model.py @@ -6,7 +6,7 @@ HERE = osp.dirname(osp.abspath(__file__)) ASSETS = osp.dirname(osp.dirname(HERE)) -model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False) +model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=None, weights_backbone=None) model.eval() traced_model = torch.jit.script(model) diff --git a/torchvision/models/__init__.py b/torchvision/models/__init__.py index 16495e8552e..83e49908348 100644 --- a/torchvision/models/__init__.py +++ b/torchvision/models/__init__.py @@ -1,20 +1,20 @@ from .alexnet import * from .convnext import * -from .resnet import * -from .vgg import * -from .squeezenet import * -from .inception import * from .densenet import * +from .efficientnet import * from .googlenet import * -from .mobilenet import * +from .inception import * from .mnasnet import * -from .shufflenetv2 import * -from .efficientnet import * +from .mobilenet import * from .regnet import * +from .resnet import * +from .shufflenetv2 import * +from .squeezenet import * +from .vgg import * from .vision_transformer import * from . import detection -from . import feature_extraction from . import optical_flow from . import quantization from . import segmentation from . import video +from ._api import get_weight diff --git a/torchvision/prototype/models/_api.py b/torchvision/models/_api.py similarity index 72% rename from torchvision/prototype/models/_api.py rename to torchvision/models/_api.py index 85b280a7dfc..e47eaf73aab 100644 --- a/torchvision/prototype/models/_api.py +++ b/torchvision/models/_api.py @@ -3,11 +3,12 @@ import sys from collections import OrderedDict from dataclasses import dataclass, fields -from typing import Any, Callable, Dict +from inspect import signature +from typing import Any, Callable, Dict, cast from torchvision._utils import StrEnum -from ..._internally_replaced_utils import load_state_dict_from_url +from .._internally_replaced_utils import load_state_dict_from_url __all__ = ["WeightsEnum", "Weights", "get_weight"] @@ -105,3 +106,38 @@ def get_weight(name: str) -> WeightsEnum: raise ValueError(f"The weight enum '{enum_name}' for the specific method couldn't be retrieved.") return weights_enum.from_str(value_name) + + +def get_enum_from_fn(fn: Callable) -> WeightsEnum: + """ + Internal method that gets the weight enum of a specific model builder method. + Might be removed after the handle_legacy_interface is removed. + + Args: + fn (Callable): The builder method used to create the model. + weight_name (str): The name of the weight enum entry of the specific model. + Returns: + WeightsEnum: The requested weight enum. 
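    Example (illustrative only, not part of this diff; it assumes a new-style builder
    whose ``weights`` parameter is annotated with its weights enum):

        >>> from torchvision.models import resnet50
        >>> get_enum_from_fn(resnet50)
        <enum 'ResNet50_Weights'>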
+ """ + sig = signature(fn) + if "weights" not in sig.parameters: + raise ValueError("The method is missing the 'weights' argument.") + + ann = signature(fn).parameters["weights"].annotation + weights_enum = None + if isinstance(ann, type) and issubclass(ann, WeightsEnum): + weights_enum = ann + else: + # handle cases like Union[Optional, T] + # TODO: Replace ann.__args__ with typing.get_args(ann) after python >= 3.8 + for t in ann.__args__: # type: ignore[union-attr] + if isinstance(t, type) and issubclass(t, WeightsEnum): + weights_enum = t + break + + if weights_enum is None: + raise ValueError( + "The WeightsEnum class for the specific method couldn't be retrieved. Make sure the typing info is correct." + ) + + return cast(WeightsEnum, weights_enum) diff --git a/torchvision/prototype/models/_meta.py b/torchvision/models/_meta.py similarity index 100% rename from torchvision/prototype/models/_meta.py rename to torchvision/models/_meta.py diff --git a/torchvision/models/_utils.py b/torchvision/models/_utils.py index f4e1cd84508..08c878a8a67 100644 --- a/torchvision/models/_utils.py +++ b/torchvision/models/_utils.py @@ -1,8 +1,14 @@ +import functools +import inspect +import warnings from collections import OrderedDict -from typing import Dict, Optional +from typing import Any, Dict, Optional, TypeVar, Callable, Tuple, Union from torch import nn +from .._utils import sequence_to_str +from ._api import WeightsEnum + class IntermediateLayerGetter(nn.ModuleDict): """ @@ -26,7 +32,7 @@ class IntermediateLayerGetter(nn.ModuleDict): Examples:: - >>> m = torchvision.models.resnet18(pretrained=True) + >>> m = torchvision.models.resnet18(weights=ResNet18_Weights.DEFAULT) >>> # extract layer1 and layer3, giving as names `feat1` and feat2` >>> new_m = torchvision.models._utils.IntermediateLayerGetter(m, >>> {'layer1': 'feat1', 'layer3': 'feat2'}) @@ -81,3 +87,158 @@ def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> if new_v < 0.9 * v: new_v += divisor return new_v + + +D = TypeVar("D") + + +def kwonly_to_pos_or_kw(fn: Callable[..., D]) -> Callable[..., D]: + """Decorates a function that uses keyword only parameters to also allow them being passed as positionals. + + For example, consider the use case of changing the signature of ``old_fn`` into the one from ``new_fn``: + + .. code:: + + def old_fn(foo, bar, baz=None): + ... + + def new_fn(foo, *, bar, baz=None): + ... + + Calling ``old_fn("foo", "bar, "baz")`` was valid, but the same call is no longer valid with ``new_fn``. To keep BC + and at the same time warn the user of the deprecation, this decorator can be used: + + .. code:: + + @kwonly_to_pos_or_kw + def new_fn(foo, *, bar, baz=None): + ... + + new_fn("foo", "bar, "baz") + """ + params = inspect.signature(fn).parameters + + try: + keyword_only_start_idx = next( + idx for idx, param in enumerate(params.values()) if param.kind == param.KEYWORD_ONLY + ) + except StopIteration: + raise TypeError(f"Found no keyword-only parameter on function '{fn.__name__}'") from None + + keyword_only_params = tuple(inspect.signature(fn).parameters)[keyword_only_start_idx:] + + @functools.wraps(fn) + def wrapper(*args: Any, **kwargs: Any) -> D: + args, keyword_only_args = args[:keyword_only_start_idx], args[keyword_only_start_idx:] + if keyword_only_args: + keyword_only_kwargs = dict(zip(keyword_only_params, keyword_only_args)) + warnings.warn( + f"Using {sequence_to_str(tuple(keyword_only_kwargs.keys()), separate_last='and ')} as positional " + f"parameter(s) is deprecated. 
Please use keyword parameter(s) instead." + ) + kwargs.update(keyword_only_kwargs) + + return fn(*args, **kwargs) + + return wrapper + + +W = TypeVar("W", bound=WeightsEnum) +M = TypeVar("M", bound=nn.Module) +V = TypeVar("V") + + +def handle_legacy_interface(**weights: Tuple[str, Union[Optional[W], Callable[[Dict[str, Any]], Optional[W]]]]): + """Decorates a model builder with the new interface to make it compatible with the old. + + In particular this handles two things: + + 1. Allows positional parameters again, but emits a deprecation warning in case they are used. See + :func:`torchvision.prototype.utils._internal.kwonly_to_pos_or_kw` for details. + 2. Handles the default value change from ``pretrained=False`` to ``weights=None`` and ``pretrained=True`` to + ``weights=Weights`` and emits a deprecation warning with instructions for the new interface. + + Args: + **weights (Tuple[str, Union[Optional[W], Callable[[Dict[str, Any]], Optional[W]]]]): Deprecated parameter + name and default value for the legacy ``pretrained=True``. The default value can be a callable in which + case it will be called with a dictionary of the keyword arguments. The only key that is guaranteed to be in + the dictionary is the deprecated parameter name passed as first element in the tuple. All other parameters + should be accessed with :meth:`~dict.get`. + """ + + def outer_wrapper(builder: Callable[..., M]) -> Callable[..., M]: + @kwonly_to_pos_or_kw + @functools.wraps(builder) + def inner_wrapper(*args: Any, **kwargs: Any) -> M: + for weights_param, (pretrained_param, default) in weights.items(): # type: ignore[union-attr] + # If neither the weights nor the pretrained parameter as passed, or the weights argument already use + # the new style arguments, there is nothing to do. Note that we cannot use `None` as sentinel for the + # weight argument, since it is a valid value. + sentinel = object() + weights_arg = kwargs.get(weights_param, sentinel) + if ( + (weights_param not in kwargs and pretrained_param not in kwargs) + or isinstance(weights_arg, WeightsEnum) + or (isinstance(weights_arg, str) and weights_arg != "legacy") + or weights_arg is None + ): + continue + + # If the pretrained parameter was passed as positional argument, it is now mapped to + # `kwargs[weights_param]`. This happens because the @kwonly_to_pos_or_kw decorator uses the current + # signature to infer the names of positionally passed arguments and thus has no knowledge that there + # used to be a pretrained parameter. + pretrained_positional = weights_arg is not sentinel + if pretrained_positional: + # We put the pretrained argument under its legacy name in the keyword argument dictionary to have a + # unified access to the value if the default value is a callable. + kwargs[pretrained_param] = pretrained_arg = kwargs.pop(weights_param) + else: + pretrained_arg = kwargs[pretrained_param] + + if pretrained_arg: + default_weights_arg = default(kwargs) if callable(default) else default + if not isinstance(default_weights_arg, WeightsEnum): + raise ValueError(f"No weights available for model {builder.__name__}") + else: + default_weights_arg = None + + if not pretrained_positional: + warnings.warn( + f"The parameter '{pretrained_param}' is deprecated, please use '{weights_param}' instead." + ) + + msg = ( + f"Arguments other than a weight enum or `None` for '{weights_param}' are deprecated. " + f"The current behavior is equivalent to passing `{weights_param}={default_weights_arg}`." 
+ ) + if pretrained_arg: + msg = ( + f"{msg} You can also use `{weights_param}={type(default_weights_arg).__name__}.DEFAULT` " + f"to get the most up-to-date weights." + ) + warnings.warn(msg) + + del kwargs[pretrained_param] + kwargs[weights_param] = default_weights_arg + + return builder(*args, **kwargs) + + return inner_wrapper + + return outer_wrapper + + +def _ovewrite_named_param(kwargs: Dict[str, Any], param: str, new_value: V) -> None: + if param in kwargs: + if kwargs[param] != new_value: + raise ValueError(f"The parameter '{param}' expected value {new_value} but got {kwargs[param]} instead.") + else: + kwargs[param] = new_value + + +def _ovewrite_value_param(param: Optional[V], new_value: V) -> V: + if param is not None: + if param != new_value: + raise ValueError(f"The parameter '{param}' expected value {new_value} but got {param} instead.") + return new_value diff --git a/torchvision/models/alexnet.py b/torchvision/models/alexnet.py index bb812febdc4..6ee5b98c673 100644 --- a/torchvision/models/alexnet.py +++ b/torchvision/models/alexnet.py @@ -1,18 +1,17 @@ -from typing import Any +from functools import partial +from typing import Any, Optional import torch import torch.nn as nn -from .._internally_replaced_utils import load_state_dict_from_url +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once +from ._api import WeightsEnum, Weights +from ._meta import _IMAGENET_CATEGORIES +from ._utils import handle_legacy_interface, _ovewrite_named_param -__all__ = ["AlexNet", "alexnet"] - - -model_urls = { - "alexnet": "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth", -} +__all__ = ["AlexNet", "AlexNet_Weights", "alexnet"] class AlexNet(nn.Module): @@ -53,17 +52,45 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -def alexnet(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> AlexNet: +class AlexNet_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/alexnet-owt-7be5be79.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + "task": "image_classification", + "architecture": "AlexNet", + "publication_year": 2012, + "num_params": 61100840, + "size": (224, 224), + "min_size": (63, 63), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg", + "acc@1": 56.522, + "acc@5": 79.066, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +@handle_legacy_interface(weights=("pretrained", AlexNet_Weights.IMAGENET1K_V1)) +def alexnet(*, weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet: r"""AlexNet model architecture from the `"One weird trick..." `_ paper. The required minimum input size of the model is 63x63. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (AlexNet_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = AlexNet_Weights.verify(weights) + + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + model = AlexNet(**kwargs) - if pretrained: - state_dict = load_state_dict_from_url(model_urls["alexnet"], progress=progress) - model.load_state_dict(state_dict) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model diff --git a/torchvision/models/convnext.py b/torchvision/models/convnext.py index 3a0dcdb31cd..8774b9a1bc2 100644 --- a/torchvision/models/convnext.py +++ b/torchvision/models/convnext.py @@ -1,18 +1,25 @@ from functools import partial -from typing import Any, Callable, Dict, List, Optional, Sequence +from typing import Any, Callable, List, Optional, Sequence import torch from torch import nn, Tensor from torch.nn import functional as F -from .._internally_replaced_utils import load_state_dict_from_url from ..ops.misc import Conv2dNormActivation from ..ops.stochastic_depth import StochasticDepth +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once +from ._api import WeightsEnum, Weights +from ._meta import _IMAGENET_CATEGORIES +from ._utils import handle_legacy_interface, _ovewrite_named_param __all__ = [ "ConvNeXt", + "ConvNeXt_Tiny_Weights", + "ConvNeXt_Small_Weights", + "ConvNeXt_Base_Weights", + "ConvNeXt_Large_Weights", "convnext_tiny", "convnext_small", "convnext_base", @@ -20,14 +27,6 @@ ] -_MODELS_URLS: Dict[str, Optional[str]] = { - "convnext_tiny": "https://download.pytorch.org/models/convnext_tiny-983f1562.pth", - "convnext_small": "https://download.pytorch.org/models/convnext_small-0c510722.pth", - "convnext_base": "https://download.pytorch.org/models/convnext_base-6075fbad.pth", - "convnext_large": "https://download.pytorch.org/models/convnext_large-ea097f82.pth", -} - - class LayerNorm2d(nn.LayerNorm): def forward(self, x: Tensor) -> Tensor: x = x.permute(0, 2, 3, 1) @@ -187,29 +186,101 @@ def forward(self, x: Tensor) -> Tensor: def _convnext( - arch: str, block_setting: List[CNBlockConfig], stochastic_depth_prob: float, - pretrained: bool, + weights: Optional[WeightsEnum], progress: bool, **kwargs: Any, ) -> ConvNeXt: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + model = ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs) - if pretrained: - if arch not in _MODELS_URLS: - raise ValueError(f"No checkpoint is available for model type {arch}") - state_dict = load_state_dict_from_url(_MODELS_URLS[arch], progress=progress) - model.load_state_dict(state_dict) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model -def convnext_tiny(*, pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ConvNeXt: +_COMMON_META = { + "task": "image_classification", + "architecture": "ConvNeXt", + "publication_year": 2022, + "size": (224, 224), + "min_size": (32, 32), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#convnext", +} + + +class ConvNeXt_Tiny_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + 
url="https://download.pytorch.org/models/convnext_tiny-983f1562.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=236), + meta={ + **_COMMON_META, + "num_params": 28589128, + "acc@1": 82.520, + "acc@5": 96.146, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class ConvNeXt_Small_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/convnext_small-0c510722.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=230), + meta={ + **_COMMON_META, + "num_params": 50223688, + "acc@1": 83.616, + "acc@5": 96.650, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class ConvNeXt_Base_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/convnext_base-6075fbad.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 88591464, + "acc@1": 84.062, + "acc@5": 96.870, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class ConvNeXt_Large_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/convnext_large-ea097f82.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 197767336, + "acc@1": 84.414, + "acc@5": 96.976, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +@handle_legacy_interface(weights=("pretrained", ConvNeXt_Tiny_Weights.IMAGENET1K_V1)) +def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt: r"""ConvNeXt Tiny model architecture from the `"A ConvNet for the 2020s" `_ paper. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ConvNeXt_Tiny_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = ConvNeXt_Tiny_Weights.verify(weights) + block_setting = [ CNBlockConfig(96, 192, 3), CNBlockConfig(192, 384, 3), @@ -217,16 +288,21 @@ def convnext_tiny(*, pretrained: bool = False, progress: bool = True, **kwargs: CNBlockConfig(768, None, 3), ] stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.1) - return _convnext("convnext_tiny", block_setting, stochastic_depth_prob, pretrained, progress, **kwargs) + return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs) -def convnext_small(*, pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ConvNeXt: +@handle_legacy_interface(weights=("pretrained", ConvNeXt_Small_Weights.IMAGENET1K_V1)) +def convnext_small( + *, weights: Optional[ConvNeXt_Small_Weights] = None, progress: bool = True, **kwargs: Any +) -> ConvNeXt: r"""ConvNeXt Small model architecture from the `"A ConvNet for the 2020s" `_ paper. 
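Sketch of how ``_ovewrite_named_param`` constrains user kwargs once a weights enum is supplied (illustrative, based on the hunk above):

from torchvision.models import convnext_tiny, ConvNeXt_Tiny_Weights

# With weights, num_classes is forced to len(weights.meta["categories"]) == 1000.
pretrained = convnext_tiny(weights=ConvNeXt_Tiny_Weights.IMAGENET1K_V1)

# A conflicting value is rejected by _ovewrite_named_param:
#   convnext_tiny(weights=ConvNeXt_Tiny_Weights.IMAGENET1K_V1, num_classes=10)  # ValueError

# Without weights, num_classes is a free hyper-parameter again.
scratch = convnext_tiny(weights=None, num_classes=10)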
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ConvNeXt_Small_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = ConvNeXt_Small_Weights.verify(weights) + block_setting = [ CNBlockConfig(96, 192, 3), CNBlockConfig(192, 384, 3), @@ -234,16 +310,19 @@ def convnext_small(*, pretrained: bool = False, progress: bool = True, **kwargs: CNBlockConfig(768, None, 3), ] stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.4) - return _convnext("convnext_small", block_setting, stochastic_depth_prob, pretrained, progress, **kwargs) + return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs) -def convnext_base(*, pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ConvNeXt: +@handle_legacy_interface(weights=("pretrained", ConvNeXt_Base_Weights.IMAGENET1K_V1)) +def convnext_base(*, weights: Optional[ConvNeXt_Base_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt: r"""ConvNeXt Base model architecture from the `"A ConvNet for the 2020s" `_ paper. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ConvNeXt_Base_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = ConvNeXt_Base_Weights.verify(weights) + block_setting = [ CNBlockConfig(128, 256, 3), CNBlockConfig(256, 512, 3), @@ -251,16 +330,21 @@ def convnext_base(*, pretrained: bool = False, progress: bool = True, **kwargs: CNBlockConfig(1024, None, 3), ] stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5) - return _convnext("convnext_base", block_setting, stochastic_depth_prob, pretrained, progress, **kwargs) + return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs) -def convnext_large(*, pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ConvNeXt: +@handle_legacy_interface(weights=("pretrained", ConvNeXt_Large_Weights.IMAGENET1K_V1)) +def convnext_large( + *, weights: Optional[ConvNeXt_Large_Weights] = None, progress: bool = True, **kwargs: Any +) -> ConvNeXt: r"""ConvNeXt Large model architecture from the `"A ConvNet for the 2020s" `_ paper. 
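Illustrative only: the per-variant ``stochastic_depth_prob`` default is popped from ``**kwargs`` by each builder (e.g. 0.1 for ``convnext_tiny``, 0.5 for ``convnext_base``), so it stays tunable when no weights are loaded:

from torchvision.models import convnext_base

# Override the 0.5 default, e.g. for a smaller dataset trained from scratch.
model = convnext_base(weights=None, num_classes=100, stochastic_depth_prob=0.2)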
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ConvNeXt_Large_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = ConvNeXt_Large_Weights.verify(weights) + block_setting = [ CNBlockConfig(192, 384, 3), CNBlockConfig(384, 768, 3), @@ -268,4 +352,4 @@ def convnext_large(*, pretrained: bool = False, progress: bool = True, **kwargs: CNBlockConfig(1536, None, 3), ] stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5) - return _convnext("convnext_large", block_setting, stochastic_depth_prob, pretrained, progress, **kwargs) + return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs) diff --git a/torchvision/models/densenet.py b/torchvision/models/densenet.py index 14e318360af..2ffb29c54cb 100644 --- a/torchvision/models/densenet.py +++ b/torchvision/models/densenet.py @@ -1,6 +1,7 @@ import re from collections import OrderedDict -from typing import Any, List, Tuple +from functools import partial +from typing import Any, List, Optional, Tuple import torch import torch.nn as nn @@ -8,18 +9,24 @@ import torch.utils.checkpoint as cp from torch import Tensor -from .._internally_replaced_utils import load_state_dict_from_url +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once +from ._api import WeightsEnum, Weights +from ._meta import _IMAGENET_CATEGORIES +from ._utils import handle_legacy_interface, _ovewrite_named_param -__all__ = ["DenseNet", "densenet121", "densenet169", "densenet201", "densenet161"] - -model_urls = { - "densenet121": "https://download.pytorch.org/models/densenet121-a639ec97.pth", - "densenet169": "https://download.pytorch.org/models/densenet169-b2777c0a.pth", - "densenet201": "https://download.pytorch.org/models/densenet201-c1103571.pth", - "densenet161": "https://download.pytorch.org/models/densenet161-8d451a50.pth", -} +__all__ = [ + "DenseNet", + "DenseNet121_Weights", + "DenseNet161_Weights", + "DenseNet169_Weights", + "DenseNet201_Weights", + "densenet121", + "densenet161", + "densenet169", + "densenet201", +] class _DenseLayer(nn.Module): @@ -220,7 +227,7 @@ def forward(self, x: Tensor) -> Tensor: return out -def _load_state_dict(model: nn.Module, model_url: str, progress: bool) -> None: +def _load_state_dict(model: nn.Module, weights: WeightsEnum, progress: bool) -> None: # '.'s are no longer allowed in module names, but previous _DenseLayer # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. # They are also in the checkpoints in model_urls. 
This pattern is used @@ -229,7 +236,7 @@ def _load_state_dict(model: nn.Module, model_url: str, progress: bool) -> None: r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$" ) - state_dict = load_state_dict_from_url(model_url, progress=progress) + state_dict = weights.get_state_dict(progress=progress) for key in list(state_dict.keys()): res = pattern.match(key) if res: @@ -240,71 +247,155 @@ def _load_state_dict(model: nn.Module, model_url: str, progress: bool) -> None: def _densenet( - arch: str, growth_rate: int, block_config: Tuple[int, int, int, int], num_init_features: int, - pretrained: bool, + weights: Optional[WeightsEnum], progress: bool, **kwargs: Any, ) -> DenseNet: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + model = DenseNet(growth_rate, block_config, num_init_features, **kwargs) - if pretrained: - _load_state_dict(model, model_urls[arch], progress) + + if weights is not None: + _load_state_dict(model=model, weights=weights, progress=progress) + return model -def densenet121(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet: +_COMMON_META = { + "task": "image_classification", + "architecture": "DenseNet", + "publication_year": 2016, + "size": (224, 224), + "min_size": (29, 29), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "recipe": "https://github.com/pytorch/vision/pull/116", +} + + +class DenseNet121_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/densenet121-a639ec97.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 7978856, + "acc@1": 74.434, + "acc@5": 91.972, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class DenseNet161_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/densenet161-8d451a50.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 28681000, + "acc@1": 77.138, + "acc@5": 93.560, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class DenseNet169_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/densenet169-b2777c0a.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 14149480, + "acc@1": 75.600, + "acc@5": 92.806, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class DenseNet201_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/densenet201-c1103571.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 20013928, + "acc@1": 76.896, + "acc@5": 93.370, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +@handle_legacy_interface(weights=("pretrained", DenseNet121_Weights.IMAGENET1K_V1)) +def densenet121(*, weights: Optional[DenseNet121_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: r"""Densenet-121 model from `"Densely Connected Convolutional Networks" `_. The required minimum input size of the model is 29x29. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (DenseNet121_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, but slower. Default: *False*. See `"paper" `_. 
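Standalone sketch of the key remapping performed by ``_load_state_dict`` above (the regex is copied from the hunk; legacy checkpoints still use the dotted ``norm.1`` / ``conv.2`` sub-module names):

import re

pattern = re.compile(
    r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
)

key = "features.denseblock1.denselayer1.norm.1.weight"
res = pattern.match(key)
new_key = res.group(1) + res.group(2) if res else key
print(new_key)  # features.denseblock1.denselayer1.norm1.weight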
""" - return _densenet("densenet121", 32, (6, 12, 24, 16), 64, pretrained, progress, **kwargs) + weights = DenseNet121_Weights.verify(weights) + return _densenet(32, (6, 12, 24, 16), 64, weights, progress, **kwargs) -def densenet161(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet: + +@handle_legacy_interface(weights=("pretrained", DenseNet161_Weights.IMAGENET1K_V1)) +def densenet161(*, weights: Optional[DenseNet161_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: r"""Densenet-161 model from `"Densely Connected Convolutional Networks" `_. The required minimum input size of the model is 29x29. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (DenseNet161_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, but slower. Default: *False*. See `"paper" `_. """ - return _densenet("densenet161", 48, (6, 12, 36, 24), 96, pretrained, progress, **kwargs) + weights = DenseNet161_Weights.verify(weights) + + return _densenet(48, (6, 12, 36, 24), 96, weights, progress, **kwargs) -def densenet169(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet: +@handle_legacy_interface(weights=("pretrained", DenseNet169_Weights.IMAGENET1K_V1)) +def densenet169(*, weights: Optional[DenseNet169_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: r"""Densenet-169 model from `"Densely Connected Convolutional Networks" `_. The required minimum input size of the model is 29x29. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (DenseNet169_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, but slower. Default: *False*. See `"paper" `_. """ - return _densenet("densenet169", 32, (6, 12, 32, 32), 64, pretrained, progress, **kwargs) + weights = DenseNet169_Weights.verify(weights) + return _densenet(32, (6, 12, 32, 32), 64, weights, progress, **kwargs) -def densenet201(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet: + +@handle_legacy_interface(weights=("pretrained", DenseNet201_Weights.IMAGENET1K_V1)) +def densenet201(*, weights: Optional[DenseNet201_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: r"""Densenet-201 model from `"Densely Connected Convolutional Networks" `_. The required minimum input size of the model is 29x29. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (DenseNet201_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, but slower. Default: *False*. See `"paper" `_. 
""" - return _densenet("densenet201", 32, (6, 12, 48, 32), 64, pretrained, progress, **kwargs) + weights = DenseNet201_Weights.verify(weights) + + return _densenet(32, (6, 12, 48, 32), 64, weights, progress, **kwargs) diff --git a/torchvision/models/detection/__init__.py b/torchvision/models/detection/__init__.py index be46f950a61..4146651c737 100644 --- a/torchvision/models/detection/__init__.py +++ b/torchvision/models/detection/__init__.py @@ -1,7 +1,7 @@ from .faster_rcnn import * -from .mask_rcnn import * +from .fcos import * from .keypoint_rcnn import * +from .mask_rcnn import * from .retinanet import * from .ssd import * from .ssdlite import * -from .fcos import * diff --git a/torchvision/models/detection/backbone_utils.py b/torchvision/models/detection/backbone_utils.py index 5ac5f179479..24215322b84 100644 --- a/torchvision/models/detection/backbone_utils.py +++ b/torchvision/models/detection/backbone_utils.py @@ -6,7 +6,8 @@ from torchvision.ops.feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool from .. import mobilenet, resnet -from .._utils import IntermediateLayerGetter +from .._api import WeightsEnum, get_enum_from_fn +from .._utils import IntermediateLayerGetter, handle_legacy_interface class BackboneWithFPN(nn.Module): @@ -55,9 +56,16 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]: return x +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: get_enum_from_fn(resnet.__dict__[kwargs["backbone_name"]]).from_str("IMAGENET1K_V1"), + ), +) def resnet_fpn_backbone( + *, backbone_name: str, - pretrained: bool, + weights: Optional[WeightsEnum], norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d, trainable_layers: int = 3, returned_layers: Optional[List[int]] = None, @@ -69,7 +77,7 @@ def resnet_fpn_backbone( Examples:: >>> from torchvision.models.detection.backbone_utils import resnet_fpn_backbone - >>> backbone = resnet_fpn_backbone('resnet50', pretrained=True, trainable_layers=3) + >>> backbone = resnet_fpn_backbone('resnet50', weights=ResNet50_Weights.DEFAULT, trainable_layers=3) >>> # get some dummy image >>> x = torch.rand(1,3,64,64) >>> # compute the output @@ -85,10 +93,10 @@ def resnet_fpn_backbone( Args: backbone_name (string): resnet architecture. Possible values are 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2' - pretrained (bool): If True, returns a model with backbone pre-trained on Imagenet + weights (WeightsEnum, optional): The pretrained weights for the model norm_layer (callable): it is recommended to use the default value. For details visit: (https://github.com/facebookresearch/maskrcnn-benchmark/issues/267) - trainable_layers (int): number of trainable (not frozen) resnet layers starting from final block. + trainable_layers (int): number of trainable (not frozen) layers starting from final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. returned_layers (list of int): The layers of the network to return. Each entry must be in ``[1, 4]``. By default all layers are returned. @@ -98,7 +106,7 @@ def resnet_fpn_backbone( a new list of feature maps and their corresponding names. By default a ``LastLevelMaxPool`` is used. 
""" - backbone = resnet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer) + backbone = resnet.__dict__[backbone_name](weights=weights, norm_layer=norm_layer) return _resnet_fpn_extractor(backbone, trainable_layers, returned_layers, extra_blocks) @@ -135,13 +143,13 @@ def _resnet_fpn_extractor( def _validate_trainable_layers( - pretrained: bool, + is_trained: bool, trainable_backbone_layers: Optional[int], max_value: int, default_value: int, ) -> int: # don't freeze any layers if pretrained model or backbone is not used - if not pretrained: + if not is_trained: if trainable_backbone_layers is not None: warnings.warn( "Changing trainable_backbone_layers has not effect if " @@ -160,16 +168,23 @@ def _validate_trainable_layers( return trainable_backbone_layers +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: get_enum_from_fn(mobilenet.__dict__[kwargs["backbone_name"]]).from_str("IMAGENET1K_V1"), + ), +) def mobilenet_backbone( + *, backbone_name: str, - pretrained: bool, + weights: Optional[WeightsEnum], fpn: bool, norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d, trainable_layers: int = 2, returned_layers: Optional[List[int]] = None, extra_blocks: Optional[ExtraFPNBlock] = None, ) -> nn.Module: - backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer) + backbone = mobilenet.__dict__[backbone_name](weights=weights, norm_layer=norm_layer) return _mobilenet_extractor(backbone, fpn, trainable_layers, returned_layers, extra_blocks) diff --git a/torchvision/models/detection/faster_rcnn.py b/torchvision/models/detection/faster_rcnn.py index 35cb968d711..2c1e6358c58 100644 --- a/torchvision/models/detection/faster_rcnn.py +++ b/torchvision/models/detection/faster_rcnn.py @@ -1,11 +1,16 @@ +from typing import Any, Optional, Union + import torch.nn.functional as F from torch import nn from torchvision.ops import MultiScaleRoIAlign -from ..._internally_replaced_utils import load_state_dict_from_url from ...ops import misc as misc_nn_ops -from ..mobilenetv3 import mobilenet_v3_large -from ..resnet import resnet50 +from ...transforms._presets import ObjectDetection, InterpolationMode +from .._api import WeightsEnum, Weights +from .._meta import _COCO_CATEGORIES +from .._utils import handle_legacy_interface, _ovewrite_value_param +from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large +from ..resnet import ResNet50_Weights, resnet50 from ._utils import overwrite_eps from .anchor_utils import AnchorGenerator from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers, _mobilenet_extractor @@ -17,9 +22,12 @@ __all__ = [ "FasterRCNN", + "FasterRCNN_ResNet50_FPN_Weights", + "FasterRCNN_MobileNet_V3_Large_FPN_Weights", + "FasterRCNN_MobileNet_V3_Large_320_FPN_Weights", "fasterrcnn_resnet50_fpn", - "fasterrcnn_mobilenet_v3_large_320_fpn", "fasterrcnn_mobilenet_v3_large_fpn", + "fasterrcnn_mobilenet_v3_large_320_fpn", ] @@ -109,7 +117,7 @@ class FasterRCNN(GeneralizedRCNN): >>> from torchvision.models.detection.rpn import AnchorGenerator >>> # load a pre-trained model for classification and return >>> # only the features - >>> backbone = torchvision.models.mobilenet_v2(pretrained=True).features + >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features >>> # FasterRCNN needs to know the number of >>> # output channels in a backbone. 
For mobilenet_v2, it's 1280 >>> # so we need to add it here @@ -316,16 +324,70 @@ def forward(self, x): return scores, bbox_deltas -model_urls = { - "fasterrcnn_resnet50_fpn_coco": "https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth", - "fasterrcnn_mobilenet_v3_large_320_fpn_coco": "https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth", - "fasterrcnn_mobilenet_v3_large_fpn_coco": "https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth", +_COMMON_META = { + "task": "image_object_detection", + "architecture": "FasterRCNN", + "publication_year": 2015, + "categories": _COCO_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, } +class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum): + COCO_V1 = Weights( + url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth", + transforms=ObjectDetection, + meta={ + **_COMMON_META, + "num_params": 41755286, + "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn", + "map": 37.0, + }, + ) + DEFAULT = COCO_V1 + + +class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum): + COCO_V1 = Weights( + url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth", + transforms=ObjectDetection, + meta={ + **_COMMON_META, + "num_params": 19386354, + "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn", + "map": 32.8, + }, + ) + DEFAULT = COCO_V1 + + +class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum): + COCO_V1 = Weights( + url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth", + transforms=ObjectDetection, + meta={ + **_COMMON_META, + "num_params": 19386354, + "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn", + "map": 22.8, + }, + ) + DEFAULT = COCO_V1 + + +@handle_legacy_interface( + weights=("pretrained", FasterRCNN_ResNet50_FPN_Weights.COCO_V1), + weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), +) def fasterrcnn_resnet50_fpn( - pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs -): + *, + weights: Optional[FasterRCNN_ResNet50_FPN_Weights] = None, + progress: bool = True, + num_classes: Optional[int] = None, + weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1, + trainable_backbone_layers: Optional[int] = None, + **kwargs: Any, +) -> FasterRCNN: """ Constructs a Faster R-CNN model with a ResNet-50-FPN backbone. 
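The two intended call patterns after this change, sketched for illustration (full COCO weights vs. ImageNet backbone only for fine-tuning):

from torchvision.models import ResNet50_Weights
from torchvision.models.detection import (
    fasterrcnn_resnet50_fpn,
    FasterRCNN_ResNet50_FPN_Weights,
)

# Fully pretrained detector: num_classes is inferred from the COCO metadata (91)
# and weights_backbone is dropped internally.
model = fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.COCO_V1)

# Fine-tuning: pretrained backbone only, custom head size.
model_ft = fasterrcnn_resnet50_fpn(
    weights=None,
    weights_backbone=ResNet50_Weights.IMAGENET1K_V1,
    num_classes=3,  # 2 foreground classes + background
)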
@@ -362,7 +424,7 @@ def fasterrcnn_resnet50_fpn( Example:: - >>> model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True) + >>> model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT) >>> # For training >>> images, boxes = torch.rand(4, 3, 600, 1200), torch.rand(4, 11, 4) >>> boxes[:, :, 2:4] = boxes[:, :, 0:2] + boxes[:, :, 2:4] @@ -384,51 +446,60 @@ def fasterrcnn_resnet50_fpn( >>> torch.onnx.export(model, x, "faster_rcnn.onnx", opset_version = 11) Args: - pretrained (bool): If True, returns a model pre-trained on COCO train2017 + weights (FasterRCNN_ResNet50_FPN_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr - num_classes (int): number of output classes of the model (including the background) - pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet - trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block. + num_classes (int, optional): number of output classes of the model (including the background) + weights_backbone (ResNet50_Weights, optional): The pretrained weights for the backbone + trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is passed (the default) this value is set to 3. """ - is_trained = pretrained or pretrained_backbone + weights = FasterRCNN_ResNet50_FPN_Weights.verify(weights) + weights_backbone = ResNet50_Weights.verify(weights_backbone) + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + elif num_classes is None: + num_classes = 91 + + is_trained = weights is not None or weights_backbone is not None trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3) norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d - if pretrained: - # no need to download the backbone if pretrained is set - pretrained_backbone = False - - backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer) + backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer) backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers) - model = FasterRCNN(backbone, num_classes, **kwargs) - if pretrained: - state_dict = load_state_dict_from_url(model_urls["fasterrcnn_resnet50_fpn_coco"], progress=progress) - model.load_state_dict(state_dict) - overwrite_eps(model, 0.0) + model = FasterRCNN(backbone, num_classes=num_classes, **kwargs) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + if weights == FasterRCNN_ResNet50_FPN_Weights.COCO_V1: + overwrite_eps(model, 0.0) + return model def _fasterrcnn_mobilenet_v3_large_fpn( - weights_name, - pretrained=False, - progress=True, - num_classes=91, - pretrained_backbone=True, - trainable_backbone_layers=None, - **kwargs, -): - is_trained = pretrained or pretrained_backbone + *, + weights: Optional[Union[FasterRCNN_MobileNet_V3_Large_FPN_Weights, FasterRCNN_MobileNet_V3_Large_320_FPN_Weights]], + progress: bool, + num_classes: Optional[int], + weights_backbone: Optional[MobileNet_V3_Large_Weights], + trainable_backbone_layers: Optional[int], + **kwargs: Any, +) -> FasterRCNN: + if weights is not None: + 
weights_backbone = None + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + elif num_classes is None: + num_classes = 91 + + is_trained = weights is not None or weights_backbone is not None trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 6, 3) norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d - if pretrained: - pretrained_backbone = False - - backbone = mobilenet_v3_large(pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer) + backbone = mobilenet_v3_large(weights=weights_backbone, progress=progress, norm_layer=norm_layer) backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers) - anchor_sizes = ( ( 32, @@ -439,21 +510,29 @@ def _fasterrcnn_mobilenet_v3_large_fpn( ), ) * 3 aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes) - model = FasterRCNN( backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios), **kwargs ) - if pretrained: - if model_urls.get(weights_name, None) is None: - raise ValueError(f"No checkpoint is available for model {weights_name}") - state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress) - model.load_state_dict(state_dict) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model +@handle_legacy_interface( + weights=("pretrained", FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.COCO_V1), + weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1), +) def fasterrcnn_mobilenet_v3_large_320_fpn( - pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs -): + *, + weights: Optional[FasterRCNN_MobileNet_V3_Large_320_FPN_Weights] = None, + progress: bool = True, + num_classes: Optional[int] = None, + weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1, + trainable_backbone_layers: Optional[int] = None, + **kwargs: Any, +) -> FasterRCNN: """ Constructs a low resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone tunned for mobile use-cases. It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See @@ -462,21 +541,23 @@ def fasterrcnn_mobilenet_v3_large_320_fpn( Example:: - >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn(pretrained=True) + >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn(weights=FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.DEFAULT) >>> model.eval() >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] >>> predictions = model(x) Args: - pretrained (bool): If True, returns a model pre-trained on COCO train2017 + weights (FasterRCNN_MobileNet_V3_Large_320_FPN_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr - num_classes (int): number of output classes of the model (including the background) - pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet - trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block. + num_classes (int, optional): number of output classes of the model (including the background) + weights_backbone (MobileNet_V3_Large_Weights, optional): The pretrained weights for the backbone + trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block. 
Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable. If ``None`` is passed (the default) this value is set to 3. """ - weights_name = "fasterrcnn_mobilenet_v3_large_320_fpn_coco" + weights = FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.verify(weights) + weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone) + defaults = { "min_size": 320, "max_size": 640, @@ -487,19 +568,28 @@ def fasterrcnn_mobilenet_v3_large_320_fpn( kwargs = {**defaults, **kwargs} return _fasterrcnn_mobilenet_v3_large_fpn( - weights_name, - pretrained=pretrained, + weights=weights, progress=progress, num_classes=num_classes, - pretrained_backbone=pretrained_backbone, + weights_backbone=weights_backbone, trainable_backbone_layers=trainable_backbone_layers, **kwargs, ) +@handle_legacy_interface( + weights=("pretrained", FasterRCNN_MobileNet_V3_Large_FPN_Weights.COCO_V1), + weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1), +) def fasterrcnn_mobilenet_v3_large_fpn( - pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs -): + *, + weights: Optional[FasterRCNN_MobileNet_V3_Large_FPN_Weights] = None, + progress: bool = True, + num_classes: Optional[int] = None, + weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1, + trainable_backbone_layers: Optional[int] = None, + **kwargs: Any, +) -> FasterRCNN: """ Constructs a high resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone. It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See @@ -508,32 +598,33 @@ def fasterrcnn_mobilenet_v3_large_fpn( Example:: - >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=True) + >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(weights=FasterRCNN_MobileNet_V3_Large_FPN_Weights.DEFAULT) >>> model.eval() >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] >>> predictions = model(x) Args: - pretrained (bool): If True, returns a model pre-trained on COCO train2017 + weights (FasterRCNN_MobileNet_V3_Large_FPN_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr - num_classes (int): number of output classes of the model (including the background) - pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet - trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block. + num_classes (int, optional): number of output classes of the model (including the background) + weights_backbone (MobileNet_V3_Large_Weights, optional): The pretrained weights for the backbone + trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block. Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable. If ``None`` is passed (the default) this value is set to 3. 
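Illustrative call for the low-resolution variant; the mobile-friendly defaults (``min_size=320``, ``max_size=640``, among others) are injected by the builder before constructing ``FasterRCNN``:

from torchvision.models.detection import (
    fasterrcnn_mobilenet_v3_large_320_fpn,
    FasterRCNN_MobileNet_V3_Large_320_FPN_Weights,
)

model = fasterrcnn_mobilenet_v3_large_320_fpn(
    weights=FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.DEFAULT
).eval()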
""" - weights_name = "fasterrcnn_mobilenet_v3_large_fpn_coco" + weights = FasterRCNN_MobileNet_V3_Large_FPN_Weights.verify(weights) + weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone) + defaults = { "rpn_score_thresh": 0.05, } kwargs = {**defaults, **kwargs} return _fasterrcnn_mobilenet_v3_large_fpn( - weights_name, - pretrained=pretrained, + weights=weights, progress=progress, num_classes=num_classes, - pretrained_backbone=pretrained_backbone, + weights_backbone=weights_backbone, trainable_backbone_layers=trainable_backbone_layers, **kwargs, ) diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py index c15702f5e18..7b1b3f87ba8 100644 --- a/torchvision/models/detection/fcos.py +++ b/torchvision/models/detection/fcos.py @@ -2,25 +2,32 @@ import warnings from collections import OrderedDict from functools import partial -from typing import Callable, Dict, List, Tuple, Optional +from typing import Any, Callable, Dict, List, Tuple, Optional import torch from torch import nn, Tensor -from ..._internally_replaced_utils import load_state_dict_from_url from ...ops import sigmoid_focal_loss, generalized_box_iou_loss from ...ops import boxes as box_ops from ...ops import misc as misc_nn_ops from ...ops.feature_pyramid_network import LastLevelP6P7 +from ...transforms._presets import ObjectDetection, InterpolationMode from ...utils import _log_api_usage_once -from ..resnet import resnet50 +from .._api import WeightsEnum, Weights +from .._meta import _COCO_CATEGORIES +from .._utils import handle_legacy_interface, _ovewrite_value_param +from ..resnet import ResNet50_Weights, resnet50 from . import _utils as det_utils from .anchor_utils import AnchorGenerator from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers from .transform import GeneralizedRCNNTransform -__all__ = ["FCOS", "fcos_resnet50_fpn"] +__all__ = [ + "FCOS", + "FCOS_ResNet50_FPN_Weights", + "fcos_resnet50_fpn", +] class FCOSHead(nn.Module): @@ -318,7 +325,7 @@ class FCOS(nn.Module): >>> from torchvision.models.detection.anchor_utils import AnchorGenerator >>> # load a pre-trained model for classification and return >>> # only the features - >>> backbone = torchvision.models.mobilenet_v2(pretrained=True).features + >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features >>> # FCOS needs to know the number of >>> # output channels in a backbone. 
For mobilenet_v2, it's 1280 >>> # so we need to add it here @@ -636,19 +643,37 @@ def forward( return self.eager_outputs(losses, detections) -model_urls = { - "fcos_resnet50_fpn_coco": "https://download.pytorch.org/models/fcos_resnet50_fpn_coco-99b0c9b7.pth", -} +class FCOS_ResNet50_FPN_Weights(WeightsEnum): + COCO_V1 = Weights( + url="https://download.pytorch.org/models/fcos_resnet50_fpn_coco-99b0c9b7.pth", + transforms=ObjectDetection, + meta={ + "task": "image_object_detection", + "architecture": "FCOS", + "publication_year": 2019, + "num_params": 32269600, + "categories": _COCO_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#fcos-resnet-50-fpn", + "map": 39.2, + }, + ) + DEFAULT = COCO_V1 +@handle_legacy_interface( + weights=("pretrained", FCOS_ResNet50_FPN_Weights.COCO_V1), + weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), +) def fcos_resnet50_fpn( - pretrained: bool = False, + *, + weights: Optional[FCOS_ResNet50_FPN_Weights] = None, progress: bool = True, - num_classes: int = 91, - pretrained_backbone: bool = True, + num_classes: Optional[int] = None, + weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1, trainable_backbone_layers: Optional[int] = None, - **kwargs, -): + **kwargs: Any, +) -> FCOS: """ Constructs a FCOS model with a ResNet-50-FPN backbone. @@ -682,34 +707,40 @@ def fcos_resnet50_fpn( Example: - >>> model = torchvision.models.detection.fcos_resnet50_fpn(pretrained=True) + >>> model = torchvision.models.detection.fcos_resnet50_fpn(weights=FCOS_ResNet50_FPN_Weights.DEFAULT) >>> model.eval() >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] >>> predictions = model(x) Args: - pretrained (bool): If True, returns a model pre-trained on COCO train2017 + weights (FCOS_ResNet50_FPN_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr - num_classes (int): number of output classes of the model (including the background) - pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet + num_classes (int, optional): number of output classes of the model (including the background) + weights_backbone (ResNet50_Weights, optional): The pretrained weights for the backbone trainable_backbone_layers (int, optional): number of trainable (not frozen) resnet layers starting from final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is passed (the default) this value is set to 3. 
Default: None """ - is_trained = pretrained or pretrained_backbone + weights = FCOS_ResNet50_FPN_Weights.verify(weights) + weights_backbone = ResNet50_Weights.verify(weights_backbone) + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + elif num_classes is None: + num_classes = 91 + + is_trained = weights is not None or weights_backbone is not None trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3) norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d - if pretrained: - # no need to download the backbone if pretrained is set - pretrained_backbone = False - - backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer) + backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer) backbone = _resnet_fpn_extractor( backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256) ) model = FCOS(backbone, num_classes, **kwargs) - if pretrained: - state_dict = load_state_dict_from_url(model_urls["fcos_resnet50_fpn_coco"], progress=progress) - model.load_state_dict(state_dict) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model diff --git a/torchvision/models/detection/keypoint_rcnn.py b/torchvision/models/detection/keypoint_rcnn.py index aadd390afc8..dc03c693e1c 100644 --- a/torchvision/models/detection/keypoint_rcnn.py +++ b/torchvision/models/detection/keypoint_rcnn.py @@ -1,16 +1,25 @@ +from typing import Any, Optional + import torch from torch import nn from torchvision.ops import MultiScaleRoIAlign -from ..._internally_replaced_utils import load_state_dict_from_url from ...ops import misc as misc_nn_ops -from ..resnet import resnet50 +from ...transforms._presets import ObjectDetection, InterpolationMode +from .._api import WeightsEnum, Weights +from .._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES +from .._utils import handle_legacy_interface, _ovewrite_value_param +from ..resnet import ResNet50_Weights, resnet50 from ._utils import overwrite_eps from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers from .faster_rcnn import FasterRCNN -__all__ = ["KeypointRCNN", "keypointrcnn_resnet50_fpn"] +__all__ = [ + "KeypointRCNN", + "KeypointRCNN_ResNet50_FPN_Weights", + "keypointrcnn_resnet50_fpn", +] class KeypointRCNN(FasterRCNN): @@ -110,7 +119,7 @@ class KeypointRCNN(FasterRCNN): >>> >>> # load a pre-trained model for classification and return >>> # only the features - >>> backbone = torchvision.models.mobilenet_v2(pretrained=True).features + >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features >>> # KeypointRCNN needs to know the number of >>> # output channels in a backbone. 
For mobilenet_v2, it's 1280 >>> # so we need to add it here @@ -296,22 +305,61 @@ def forward(self, x): ) -model_urls = { - # legacy model for BC reasons, see https://github.com/pytorch/vision/issues/1606 - "keypointrcnn_resnet50_fpn_coco_legacy": "https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth", - "keypointrcnn_resnet50_fpn_coco": "https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth", +_COMMON_META = { + "task": "image_object_detection", + "architecture": "KeypointRCNN", + "publication_year": 2017, + "categories": _COCO_PERSON_CATEGORIES, + "keypoint_names": _COCO_PERSON_KEYPOINT_NAMES, + "interpolation": InterpolationMode.BILINEAR, } +class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum): + COCO_LEGACY = Weights( + url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth", + transforms=ObjectDetection, + meta={ + **_COMMON_META, + "num_params": 59137258, + "recipe": "https://github.com/pytorch/vision/issues/1606", + "map": 50.6, + "map_kp": 61.1, + }, + ) + COCO_V1 = Weights( + url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth", + transforms=ObjectDetection, + meta={ + **_COMMON_META, + "num_params": 59137258, + "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#keypoint-r-cnn", + "map": 54.6, + "map_kp": 65.0, + }, + ) + DEFAULT = COCO_V1 + + +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: KeypointRCNN_ResNet50_FPN_Weights.COCO_LEGACY + if kwargs["pretrained"] == "legacy" + else KeypointRCNN_ResNet50_FPN_Weights.COCO_V1, + ), + weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), +) def keypointrcnn_resnet50_fpn( - pretrained=False, - progress=True, - num_classes=2, - num_keypoints=17, - pretrained_backbone=True, - trainable_backbone_layers=None, - **kwargs, -): + *, + weights: Optional[KeypointRCNN_ResNet50_FPN_Weights] = None, + progress: bool = True, + num_classes: Optional[int] = None, + num_keypoints: Optional[int] = None, + weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1, + trainable_backbone_layers: Optional[int] = None, + **kwargs: Any, +) -> KeypointRCNN: """ Constructs a Keypoint R-CNN model with a ResNet-50-FPN backbone. @@ -350,7 +398,7 @@ def keypointrcnn_resnet50_fpn( Example:: - >>> model = torchvision.models.detection.keypointrcnn_resnet50_fpn(pretrained=True) + >>> model = torchvision.models.detection.keypointrcnn_resnet50_fpn(weights=KeypointRCNN_ResNet50_FPN_Weights.DEFAULT) >>> model.eval() >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] >>> predictions = model(x) @@ -359,31 +407,39 @@ def keypointrcnn_resnet50_fpn( >>> torch.onnx.export(model, x, "keypoint_rcnn.onnx", opset_version = 11) Args: - pretrained (bool): If True, returns a model pre-trained on COCO train2017 + weights (KeypointRCNN_ResNet50_FPN_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr - num_classes (int): number of output classes of the model (including the background) - num_keypoints (int): number of keypoints, default 17 - pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet - trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block. 
+ num_classes (int, optional): number of output classes of the model (including the background) + num_keypoints (int, optional): number of keypoints + weights_backbone (ResNet50_Weights, optional): The pretrained weights for the backbone + trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is passed (the default) this value is set to 3. """ - is_trained = pretrained or pretrained_backbone + weights = KeypointRCNN_ResNet50_FPN_Weights.verify(weights) + weights_backbone = ResNet50_Weights.verify(weights_backbone) + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + num_keypoints = _ovewrite_value_param(num_keypoints, len(weights.meta["keypoint_names"])) + else: + if num_classes is None: + num_classes = 2 + if num_keypoints is None: + num_keypoints = 17 + + is_trained = weights is not None or weights_backbone is not None trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3) norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d - if pretrained: - # no need to download the backbone if pretrained is set - pretrained_backbone = False - - backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer) + backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer) backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers) model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs) - if pretrained: - key = "keypointrcnn_resnet50_fpn_coco" - if pretrained == "legacy": - key += "_legacy" - state_dict = load_state_dict_from_url(model_urls[key], progress=progress) - model.load_state_dict(state_dict) - overwrite_eps(model, 0.0) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + if weights == KeypointRCNN_ResNet50_FPN_Weights.COCO_V1: + overwrite_eps(model, 0.0) + return model diff --git a/torchvision/models/detection/mask_rcnn.py b/torchvision/models/detection/mask_rcnn.py index c733613452a..a6cb731c0df 100644 --- a/torchvision/models/detection/mask_rcnn.py +++ b/torchvision/models/detection/mask_rcnn.py @@ -1,17 +1,23 @@ from collections import OrderedDict +from typing import Any, Optional from torch import nn from torchvision.ops import MultiScaleRoIAlign -from ..._internally_replaced_utils import load_state_dict_from_url from ...ops import misc as misc_nn_ops -from ..resnet import resnet50 +from ...transforms._presets import ObjectDetection, InterpolationMode +from .._api import WeightsEnum, Weights +from .._meta import _COCO_CATEGORIES +from .._utils import handle_legacy_interface, _ovewrite_value_param +from ..resnet import ResNet50_Weights, resnet50 from ._utils import overwrite_eps from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers from .faster_rcnn import FasterRCNN + __all__ = [ "MaskRCNN", + "MaskRCNN_ResNet50_FPN_Weights", "maskrcnn_resnet50_fpn", ] @@ -112,7 +118,7 @@ class MaskRCNN(FasterRCNN): >>> >>> # load a pre-trained model for classification and return >>> # only the features - >>> backbone = torchvision.models.mobilenet_v2(pretrained=True).features + >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features >>> # MaskRCNN needs to know the number of >>> # output channels in a 
backbone. For mobilenet_v2, it's 1280 >>> # so we need to add it here @@ -299,14 +305,38 @@ def __init__(self, in_channels, dim_reduced, num_classes): # nn.init.constant_(param, 0) -model_urls = { - "maskrcnn_resnet50_fpn_coco": "https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth", -} - - +class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum): + COCO_V1 = Weights( + url="https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth", + transforms=ObjectDetection, + meta={ + "task": "image_object_detection", + "architecture": "MaskRCNN", + "publication_year": 2017, + "num_params": 44401393, + "categories": _COCO_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#mask-r-cnn", + "map": 37.9, + "map_mask": 34.6, + }, + ) + DEFAULT = COCO_V1 + + +@handle_legacy_interface( + weights=("pretrained", MaskRCNN_ResNet50_FPN_Weights.COCO_V1), + weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), +) def maskrcnn_resnet50_fpn( - pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs -): + *, + weights: Optional[MaskRCNN_ResNet50_FPN_Weights] = None, + progress: bool = True, + num_classes: Optional[int] = None, + weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1, + trainable_backbone_layers: Optional[int] = None, + **kwargs: Any, +) -> MaskRCNN: """ Constructs a Mask R-CNN model with a ResNet-50-FPN backbone. @@ -346,7 +376,7 @@ def maskrcnn_resnet50_fpn( Example:: - >>> model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True) + >>> model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights=MaskRCNN_ResNet50_FPN_Weights.DEFAULT) >>> model.eval() >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] >>> predictions = model(x) @@ -355,27 +385,34 @@ def maskrcnn_resnet50_fpn( >>> torch.onnx.export(model, x, "mask_rcnn.onnx", opset_version = 11) Args: - pretrained (bool): If True, returns a model pre-trained on COCO train2017 + weights (MaskRCNN_ResNet50_FPN_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr - num_classes (int): number of output classes of the model (including the background) - pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet - trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block. + num_classes (int, optional): number of output classes of the model (including the background) + weights_backbone (ResNet50_Weights, optional): The pretrained weights for the backbone + trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is passed (the default) this value is set to 3. 
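Usage sketch mirroring the updated Mask R-CNN docstring example (illustrative only):

import torch
from torchvision.models.detection import maskrcnn_resnet50_fpn, MaskRCNN_ResNet50_FPN_Weights

model = maskrcnn_resnet50_fpn(weights=MaskRCNN_ResNet50_FPN_Weights.DEFAULT).eval()
images = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
predictions = model(images)  # per-image dicts include 'masks' alongside 'boxes', 'labels', 'scores'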
""" - is_trained = pretrained or pretrained_backbone + weights = MaskRCNN_ResNet50_FPN_Weights.verify(weights) + weights_backbone = ResNet50_Weights.verify(weights_backbone) + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + elif num_classes is None: + num_classes = 91 + + is_trained = weights is not None or weights_backbone is not None trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3) norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d - if pretrained: - # no need to download the backbone if pretrained is set - pretrained_backbone = False - - backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer) + backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer) backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers) - model = MaskRCNN(backbone, num_classes, **kwargs) - if pretrained: - state_dict = load_state_dict_from_url(model_urls["maskrcnn_resnet50_fpn_coco"], progress=progress) - model.load_state_dict(state_dict) - overwrite_eps(model, 0.0) + model = MaskRCNN(backbone, num_classes=num_classes, **kwargs) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + if weights == MaskRCNN_ResNet50_FPN_Weights.COCO_V1: + overwrite_eps(model, 0.0) + return model diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index 6d6463d6894..2242c1e09bb 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -1,18 +1,21 @@ import math import warnings from collections import OrderedDict -from typing import Dict, List, Tuple, Optional +from typing import Any, Dict, List, Tuple, Optional import torch from torch import nn, Tensor -from ..._internally_replaced_utils import load_state_dict_from_url from ...ops import sigmoid_focal_loss from ...ops import boxes as box_ops from ...ops import misc as misc_nn_ops from ...ops.feature_pyramid_network import LastLevelP6P7 +from ...transforms._presets import ObjectDetection, InterpolationMode from ...utils import _log_api_usage_once -from ..resnet import resnet50 +from .._api import WeightsEnum, Weights +from .._meta import _COCO_CATEGORIES +from .._utils import handle_legacy_interface, _ovewrite_value_param +from ..resnet import ResNet50_Weights, resnet50 from . import _utils as det_utils from ._utils import overwrite_eps from .anchor_utils import AnchorGenerator @@ -20,7 +23,11 @@ from .transform import GeneralizedRCNNTransform -__all__ = ["RetinaNet", "retinanet_resnet50_fpn"] +__all__ = [ + "RetinaNet", + "RetinaNet_ResNet50_FPN_Weights", + "retinanet_resnet50_fpn", +] def _sum(x: List[Tensor]) -> Tensor: @@ -286,7 +293,7 @@ class RetinaNet(nn.Module): >>> from torchvision.models.detection.anchor_utils import AnchorGenerator >>> # load a pre-trained model for classification and return >>> # only the features - >>> backbone = torchvision.models.mobilenet_v2(pretrained=True).features + >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features >>> # RetinaNet needs to know the number of >>> # output channels in a backbone. 
For mobilenet_v2, it's 1280 >>> # so we need to add it here @@ -578,14 +585,37 @@ def forward(self, images, targets=None): return self.eager_outputs(losses, detections) -model_urls = { - "retinanet_resnet50_fpn_coco": "https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth", -} +class RetinaNet_ResNet50_FPN_Weights(WeightsEnum): + COCO_V1 = Weights( + url="https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth", + transforms=ObjectDetection, + meta={ + "task": "image_object_detection", + "architecture": "RetinaNet", + "publication_year": 2017, + "num_params": 34014999, + "categories": _COCO_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#retinanet", + "map": 36.4, + }, + ) + DEFAULT = COCO_V1 +@handle_legacy_interface( + weights=("pretrained", RetinaNet_ResNet50_FPN_Weights.COCO_V1), + weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), +) def retinanet_resnet50_fpn( - pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs -): + *, + weights: Optional[RetinaNet_ResNet50_FPN_Weights] = None, + progress: bool = True, + num_classes: Optional[int] = None, + weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1, + trainable_backbone_layers: Optional[int] = None, + **kwargs: Any, +) -> RetinaNet: """ Constructs a RetinaNet model with a ResNet-50-FPN backbone. @@ -619,36 +649,43 @@ def retinanet_resnet50_fpn( Example:: - >>> model = torchvision.models.detection.retinanet_resnet50_fpn(pretrained=True) + >>> model = torchvision.models.detection.retinanet_resnet50_fpn(weights=RetinaNet_ResNet50_FPN_Weights.DEFAULT) >>> model.eval() >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] >>> predictions = model(x) Args: - pretrained (bool): If True, returns a model pre-trained on COCO train2017 + weights (RetinaNet_ResNet50_FPN_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr - num_classes (int): number of output classes of the model (including the background) - pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet - trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block. + num_classes (int, optional): number of output classes of the model (including the background) + weights_backbone (ResNet50_Weights, optional): The pretrained weights for the backbone + trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is passed (the default) this value is set to 3. 
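As a side note (not part of the patch), the handle_legacy_interface decorator shown above keeps the old keywords working: a legacy call like the sketch below should be translated to the new weights arguments, presumably with a deprecation-style warning (the warning text is not shown in this diff):

from torchvision.models.detection import retinanet_resnet50_fpn, RetinaNet_ResNet50_FPN_Weights

# Old-style call: the decorator maps pretrained=True -> weights=RetinaNet_ResNet50_FPN_Weights.COCO_V1.
legacy_model = retinanet_resnet50_fpn(pretrained=True)

# New-style equivalent; per the builder above, passing weights derives num_classes from
# weights.meta["categories"] and sets weights_backbone to None.
model = retinanet_resnet50_fpn(weights=RetinaNet_ResNet50_FPN_Weights.COCO_V1)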
""" - is_trained = pretrained or pretrained_backbone + weights = RetinaNet_ResNet50_FPN_Weights.verify(weights) + weights_backbone = ResNet50_Weights.verify(weights_backbone) + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + elif num_classes is None: + num_classes = 91 + + is_trained = weights is not None or weights_backbone is not None trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3) norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d - if pretrained: - # no need to download the backbone if pretrained is set - pretrained_backbone = False - - backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer) + backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer) # skip P2 because it generates too many anchors (according to their paper) backbone = _resnet_fpn_extractor( backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256) ) model = RetinaNet(backbone, num_classes, **kwargs) - if pretrained: - state_dict = load_state_dict_from_url(model_urls["retinanet_resnet50_fpn_coco"], progress=progress) - model.load_state_dict(state_dict) - overwrite_eps(model, 0.0) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + if weights == RetinaNet_ResNet50_FPN_Weights.COCO_V1: + overwrite_eps(model, 0.0) + return model diff --git a/torchvision/models/detection/roi_heads.py b/torchvision/models/detection/roi_heads.py index 9f2ef20d17c..d2abebfca68 100644 --- a/torchvision/models/detection/roi_heads.py +++ b/torchvision/models/detection/roi_heads.py @@ -4,8 +4,7 @@ import torch.nn.functional as F import torchvision from torch import nn, Tensor -from torchvision.ops import boxes as box_ops -from torchvision.ops import roi_align +from torchvision.ops import boxes as box_ops, roi_align from . import _utils as det_utils diff --git a/torchvision/models/detection/ssd.py b/torchvision/models/detection/ssd.py index bd7f1b2863f..a3b8ffda178 100644 --- a/torchvision/models/detection/ssd.py +++ b/torchvision/models/detection/ssd.py @@ -6,27 +6,42 @@ import torch.nn.functional as F from torch import nn, Tensor -from ..._internally_replaced_utils import load_state_dict_from_url from ...ops import boxes as box_ops +from ...transforms._presets import ObjectDetection, InterpolationMode from ...utils import _log_api_usage_once -from .. import vgg +from .._api import WeightsEnum, Weights +from .._meta import _COCO_CATEGORIES +from .._utils import handle_legacy_interface, _ovewrite_value_param +from ..vgg import VGG, VGG16_Weights, vgg16 from . import _utils as det_utils from .anchor_utils import DefaultBoxGenerator from .backbone_utils import _validate_trainable_layers from .transform import GeneralizedRCNNTransform -__all__ = ["SSD", "ssd300_vgg16"] -model_urls = { - "ssd300_vgg16_coco": "https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth", -} - -backbone_urls = { - # We port the features of a VGG16 backbone trained by amdegroot because unlike the one on TorchVision, it uses the - # same input standardization method as the paper. Ref: https://s3.amazonaws.com/amdegroot-models/vgg16_reducedfc.pth - # Only the `features` weights have proper values, those on the `classifier` module are filled with nans. 
- "vgg16_features": "https://download.pytorch.org/models/vgg16_features-amdegroot-88682ab5.pth" -} +__all__ = [ + "SSD300_VGG16_Weights", + "ssd300_vgg16", +] + + +class SSD300_VGG16_Weights(WeightsEnum): + COCO_V1 = Weights( + url="https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth", + transforms=ObjectDetection, + meta={ + "task": "image_object_detection", + "architecture": "SSD", + "publication_year": 2015, + "num_params": 35641826, + "size": (300, 300), + "categories": _COCO_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#ssd300-vgg16", + "map": 25.1, + }, + ) + DEFAULT = COCO_V1 def _xavier_init(conv: nn.Module): @@ -528,7 +543,7 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]: return OrderedDict([(str(i), v) for i, v in enumerate(output)]) -def _vgg_extractor(backbone: vgg.VGG, highres: bool, trainable_layers: int): +def _vgg_extractor(backbone: VGG, highres: bool, trainable_layers: int): backbone = backbone.features # Gather the indices of maxpools. These are the locations of output blocks. stage_indices = [0] + [i for i, b in enumerate(backbone) if isinstance(b, nn.MaxPool2d)][:-1] @@ -546,14 +561,19 @@ def _vgg_extractor(backbone: vgg.VGG, highres: bool, trainable_layers: int): return SSDFeatureExtractorVGG(backbone, highres) +@handle_legacy_interface( + weights=("pretrained", SSD300_VGG16_Weights.COCO_V1), + weights_backbone=("pretrained_backbone", VGG16_Weights.IMAGENET1K_FEATURES), +) def ssd300_vgg16( - pretrained: bool = False, + *, + weights: Optional[SSD300_VGG16_Weights] = None, progress: bool = True, - num_classes: int = 91, - pretrained_backbone: bool = True, + num_classes: Optional[int] = None, + weights_backbone: Optional[VGG16_Weights] = VGG16_Weights.IMAGENET1K_FEATURES, trainable_backbone_layers: Optional[int] = None, **kwargs: Any, -): +) -> SSD: """Constructs an SSD model with input size 300x300 and a VGG16 backbone. Reference: `"SSD: Single Shot MultiBox Detector" `_. @@ -585,37 +605,38 @@ def ssd300_vgg16( Example: - >>> model = torchvision.models.detection.ssd300_vgg16(pretrained=True) + >>> model = torchvision.models.detection.ssd300_vgg16(weights=SSD300_VGG16_Weights.DEFAULT) >>> model.eval() >>> x = [torch.rand(3, 300, 300), torch.rand(3, 500, 400)] >>> predictions = model(x) Args: - pretrained (bool): If True, returns a model pre-trained on COCO train2017 + weights (SSD300_VGG16_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr - num_classes (int): number of output classes of the model (including the background) - pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet - trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block. + num_classes (int, optional): number of output classes of the model (including the background) + weights_backbone (VGG16_Weights, optional): The pretrained weights for the backbone + trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is passed (the default) this value is set to 4. 
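Worth spelling out (not part of the patch): the ported VGG16 features checkpoint that used to live in the ad-hoc backbone_urls dict is now reached through VGG16_Weights.IMAGENET1K_FEATURES, per the legacy-interface mapping above, so the backbone goes through the same weights mechanism as the detector itself. A hedged sketch, with a hypothetical class count for the fine-tuning case:

from torchvision.models.detection import ssd300_vgg16, SSD300_VGG16_Weights
from torchvision.models import VGG16_Weights

# Fully pretrained detector on COCO.
detector = ssd300_vgg16(weights=SSD300_VGG16_Weights.COCO_V1)

# Backbone-only pretraining, e.g. for fine-tuning; num_classes=11 is an illustrative value.
finetune = ssd300_vgg16(weights=None, weights_backbone=VGG16_Weights.IMAGENET1K_FEATURES, num_classes=11)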
""" + weights = SSD300_VGG16_Weights.verify(weights) + weights_backbone = VGG16_Weights.verify(weights_backbone) + if "size" in kwargs: - warnings.warn("The size of the model is already fixed; ignoring the argument.") + warnings.warn("The size of the model is already fixed; ignoring the parameter.") + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + elif num_classes is None: + num_classes = 91 trainable_backbone_layers = _validate_trainable_layers( - pretrained or pretrained_backbone, trainable_backbone_layers, 5, 4 + weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 4 ) - if pretrained: - # no need to download the backbone if pretrained is set - pretrained_backbone = False - # Use custom backbones more appropriate for SSD - backbone = vgg.vgg16(pretrained=False, progress=progress) - if pretrained_backbone: - state_dict = load_state_dict_from_url(backbone_urls["vgg16_features"], progress=progress) - backbone.load_state_dict(state_dict) - + backbone = vgg16(weights=weights_backbone, progress=progress) backbone = _vgg_extractor(backbone, False, trainable_backbone_layers) anchor_generator = DefaultBoxGenerator( [[2], [2, 3], [2, 3], [2, 3], [2], [2]], @@ -628,12 +649,10 @@ def ssd300_vgg16( "image_mean": [0.48235, 0.45882, 0.40784], "image_std": [1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0], # undo the 0-1 scaling of toTensor } - kwargs = {**defaults, **kwargs} + kwargs: Any = {**defaults, **kwargs} model = SSD(backbone, anchor_generator, (300, 300), num_classes, **kwargs) - if pretrained: - weights_name = "ssd300_vgg16_coco" - if model_urls.get(weights_name, None) is None: - raise ValueError(f"No checkpoint is available for model {weights_name}") - state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress) - model.load_state_dict(state_dict) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model diff --git a/torchvision/models/detection/ssdlite.py b/torchvision/models/detection/ssdlite.py index 1c59814f8d4..2e890356417 100644 --- a/torchvision/models/detection/ssdlite.py +++ b/torchvision/models/detection/ssdlite.py @@ -6,21 +6,24 @@ import torch from torch import nn, Tensor -from ..._internally_replaced_utils import load_state_dict_from_url from ...ops.misc import Conv2dNormActivation +from ...transforms._presets import ObjectDetection, InterpolationMode from ...utils import _log_api_usage_once from .. import mobilenet +from .._api import WeightsEnum, Weights +from .._meta import _COCO_CATEGORIES +from .._utils import handle_legacy_interface, _ovewrite_value_param +from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large from . 
import _utils as det_utils from .anchor_utils import DefaultBoxGenerator from .backbone_utils import _validate_trainable_layers from .ssd import SSD, SSDScoringHead -__all__ = ["ssdlite320_mobilenet_v3_large"] - -model_urls = { - "ssdlite320_mobilenet_v3_large_coco": "https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth" -} +__all__ = [ + "SSDLite320_MobileNet_V3_Large_Weights", + "ssdlite320_mobilenet_v3_large", +] # Building blocks of SSDlite as described in section 6.2 of MobileNetV2 paper @@ -181,15 +184,39 @@ def _mobilenet_extractor( return SSDLiteFeatureExtractorMobileNet(backbone, stage_indices[-2], norm_layer) +class SSDLite320_MobileNet_V3_Large_Weights(WeightsEnum): + COCO_V1 = Weights( + url="https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth", + transforms=ObjectDetection, + meta={ + "task": "image_object_detection", + "architecture": "SSDLite", + "publication_year": 2018, + "num_params": 3440060, + "size": (320, 320), + "categories": _COCO_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#ssdlite320-mobilenetv3-large", + "map": 21.3, + }, + ) + DEFAULT = COCO_V1 + + +@handle_legacy_interface( + weights=("pretrained", SSDLite320_MobileNet_V3_Large_Weights.COCO_V1), + weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1), +) def ssdlite320_mobilenet_v3_large( - pretrained: bool = False, + *, + weights: Optional[SSDLite320_MobileNet_V3_Large_Weights] = None, progress: bool = True, - num_classes: int = 91, - pretrained_backbone: bool = False, + num_classes: Optional[int] = None, + weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1, trainable_backbone_layers: Optional[int] = None, norm_layer: Optional[Callable[..., nn.Module]] = None, **kwargs: Any, -): +) -> SSD: """Constructs an SSDlite model with input size 320x320 and a MobileNetV3 Large backbone, as described at `"Searching for MobileNetV3" `_ and @@ -200,41 +227,47 @@ def ssdlite320_mobilenet_v3_large( Example: - >>> model = torchvision.models.detection.ssdlite320_mobilenet_v3_large(pretrained=True) + >>> model = torchvision.models.detection.ssdlite320_mobilenet_v3_large(weights=SSDLite320_MobileNet_V3_Large_Weights.DEFAULT) >>> model.eval() >>> x = [torch.rand(3, 320, 320), torch.rand(3, 500, 400)] >>> predictions = model(x) Args: - pretrained (bool): If True, returns a model pre-trained on COCO train2017 + weights (SSDLite320_MobileNet_V3_Large_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr - num_classes (int): number of output classes of the model (including the background) - pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet - trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block. + num_classes (int, optional): number of output classes of the model (including the background) + weights_backbone (MobileNet_V3_Large_Weights, optional): The pretrained weights for the backbone + trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block. Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable. If ``None`` is passed (the default) this value is set to 6. norm_layer (callable, optional): Module specifying the normalization layer to use.
""" + weights = SSDLite320_MobileNet_V3_Large_Weights.verify(weights) + weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone) + if "size" in kwargs: - warnings.warn("The size of the model is already fixed; ignoring the argument.") + warnings.warn("The size of the model is already fixed; ignoring the parameter.") + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + elif num_classes is None: + num_classes = 91 trainable_backbone_layers = _validate_trainable_layers( - pretrained or pretrained_backbone, trainable_backbone_layers, 6, 6 + weights is not None or weights_backbone is not None, trainable_backbone_layers, 6, 6 ) - if pretrained: - pretrained_backbone = False - # Enable reduced tail if no pretrained backbone is selected. See Table 6 of MobileNetV3 paper. - reduce_tail = not pretrained_backbone + reduce_tail = weights_backbone is None if norm_layer is None: norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03) - backbone = mobilenet.mobilenet_v3_large( - pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer, reduced_tail=reduce_tail, **kwargs + backbone = mobilenet_v3_large( + weights=weights_backbone, progress=progress, norm_layer=norm_layer, reduced_tail=reduce_tail, **kwargs ) - if not pretrained_backbone: + if weights_backbone is None: # Change the default initialization scheme if not pretrained _normal_init(backbone) backbone = _mobilenet_extractor( @@ -262,7 +295,7 @@ def ssdlite320_mobilenet_v3_large( "image_mean": [0.5, 0.5, 0.5], "image_std": [0.5, 0.5, 0.5], } - kwargs = {**defaults, **kwargs} + kwargs: Any = {**defaults, **kwargs} model = SSD( backbone, anchor_generator, @@ -272,10 +305,7 @@ def ssdlite320_mobilenet_v3_large( **kwargs, ) - if pretrained: - weights_name = "ssdlite320_mobilenet_v3_large_coco" - if model_urls.get(weights_name, None) is None: - raise ValueError(f"No checkpoint is available for model {weights_name}") - state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress) - model.load_state_dict(state_dict) + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model diff --git a/torchvision/models/efficientnet.py b/torchvision/models/efficientnet.py index f8238912ffd..b9d3b9b30c9 100644 --- a/torchvision/models/efficientnet.py +++ b/torchvision/models/efficientnet.py @@ -9,14 +9,27 @@ from torch import nn, Tensor from torchvision.ops import StochasticDepth -from .._internally_replaced_utils import load_state_dict_from_url from ..ops.misc import Conv2dNormActivation, SqueezeExcitation +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once -from ._utils import _make_divisible +from ._api import WeightsEnum, Weights +from ._meta import _IMAGENET_CATEGORIES +from ._utils import handle_legacy_interface, _ovewrite_named_param, _make_divisible __all__ = [ "EfficientNet", + "EfficientNet_B0_Weights", + "EfficientNet_B1_Weights", + "EfficientNet_B2_Weights", + "EfficientNet_B3_Weights", + "EfficientNet_B4_Weights", + "EfficientNet_B5_Weights", + "EfficientNet_B6_Weights", + "EfficientNet_B7_Weights", + "EfficientNet_V2_S_Weights", + "EfficientNet_V2_M_Weights", + "EfficientNet_V2_L_Weights", "efficientnet_b0", "efficientnet_b1", "efficientnet_b2", @@ -31,25 +44,6 @@ ] -model_urls = { - # Weights ported from https://github.com/rwightman/pytorch-image-models/ - "efficientnet_b0": 
"https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth", - "efficientnet_b1": "https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth", - "efficientnet_b2": "https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth", - "efficientnet_b3": "https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth", - "efficientnet_b4": "https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth", - # Weights ported from https://github.com/lukemelas/EfficientNet-PyTorch/ - "efficientnet_b5": "https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth", - "efficientnet_b6": "https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth", - "efficientnet_b7": "https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth", - # Weights trained with TorchVision - "efficientnet_v2_s": "https://download.pytorch.org/models/efficientnet_v2_s-dd5fe13b.pth", - "efficientnet_v2_m": "https://download.pytorch.org/models/efficientnet_v2_m-dc08266a.pth", - # Weights ported from TF - "efficientnet_v2_l": "https://download.pytorch.org/models/efficientnet_v2_l-59c71312.pth", -} - - @dataclass class _MBConvConfig: expand_ratio: float @@ -362,20 +356,21 @@ def forward(self, x: Tensor) -> Tensor: def _efficientnet( - arch: str, inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]], dropout: float, last_channel: Optional[int], - pretrained: bool, + weights: Optional[WeightsEnum], progress: bool, **kwargs: Any, ) -> EfficientNet: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + model = EfficientNet(inverted_residual_setting, dropout, last_channel=last_channel, **kwargs) - if pretrained: - if model_urls.get(arch, None) is None: - raise ValueError(f"No checkpoint is available for model type {arch}") - state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) - model.load_state_dict(state_dict) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model @@ -434,208 +429,484 @@ def _efficientnet_conf( return inverted_residual_setting, last_channel -def efficientnet_b0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: +_COMMON_META = { + "task": "image_classification", + "categories": _IMAGENET_CATEGORIES, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#efficientnet", +} + + +_COMMON_META_V1 = { + **_COMMON_META, + "architecture": "EfficientNet", + "publication_year": 2019, + "interpolation": InterpolationMode.BICUBIC, + "min_size": (1, 1), +} + + +_COMMON_META_V2 = { + **_COMMON_META, + "architecture": "EfficientNetV2", + "publication_year": 2021, + "interpolation": InterpolationMode.BILINEAR, + "min_size": (33, 33), +} + + +class EfficientNet_B0_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth", + transforms=partial( + ImageClassification, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC + ), + meta={ + **_COMMON_META_V1, + "num_params": 5288548, + "size": (224, 224), + "acc@1": 77.692, + "acc@5": 93.532, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class EfficientNet_B1_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth", + transforms=partial( + ImageClassification, 
crop_size=240, resize_size=256, interpolation=InterpolationMode.BICUBIC + ), + meta={ + **_COMMON_META_V1, + "num_params": 7794184, + "size": (240, 240), + "acc@1": 78.642, + "acc@5": 94.186, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/efficientnet_b1-c27df63c.pth", + transforms=partial( + ImageClassification, crop_size=240, resize_size=255, interpolation=InterpolationMode.BILINEAR + ), + meta={ + **_COMMON_META_V1, + "num_params": 7794184, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-lr-wd-crop-tuning", + "interpolation": InterpolationMode.BILINEAR, + "size": (240, 240), + "acc@1": 79.838, + "acc@5": 94.934, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class EfficientNet_B2_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth", + transforms=partial( + ImageClassification, crop_size=288, resize_size=288, interpolation=InterpolationMode.BICUBIC + ), + meta={ + **_COMMON_META_V1, + "num_params": 9109994, + "size": (288, 288), + "acc@1": 80.608, + "acc@5": 95.310, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class EfficientNet_B3_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth", + transforms=partial( + ImageClassification, crop_size=300, resize_size=320, interpolation=InterpolationMode.BICUBIC + ), + meta={ + **_COMMON_META_V1, + "num_params": 12233232, + "size": (300, 300), + "acc@1": 82.008, + "acc@5": 96.054, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class EfficientNet_B4_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth", + transforms=partial( + ImageClassification, crop_size=380, resize_size=384, interpolation=InterpolationMode.BICUBIC + ), + meta={ + **_COMMON_META_V1, + "num_params": 19341616, + "size": (380, 380), + "acc@1": 83.384, + "acc@5": 96.594, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class EfficientNet_B5_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth", + transforms=partial( + ImageClassification, crop_size=456, resize_size=456, interpolation=InterpolationMode.BICUBIC + ), + meta={ + **_COMMON_META_V1, + "num_params": 30389784, + "size": (456, 456), + "acc@1": 83.444, + "acc@5": 96.628, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class EfficientNet_B6_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth", + transforms=partial( + ImageClassification, crop_size=528, resize_size=528, interpolation=InterpolationMode.BICUBIC + ), + meta={ + **_COMMON_META_V1, + "num_params": 43040704, + "size": (528, 528), + "acc@1": 84.008, + "acc@5": 96.916, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class EfficientNet_B7_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth", + transforms=partial( + ImageClassification, crop_size=600, resize_size=600, interpolation=InterpolationMode.BICUBIC + ), + meta={ + **_COMMON_META_V1, + "num_params": 66347960, + "size": (600, 600), + "acc@1": 84.122, + "acc@5": 96.908, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class EfficientNet_V2_S_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/efficientnet_v2_s-dd5fe13b.pth", + transforms=partial( + ImageClassification, + crop_size=384, + 
resize_size=384, + interpolation=InterpolationMode.BILINEAR, + ), + meta={ + **_COMMON_META_V2, + "num_params": 21458488, + "size": (384, 384), + "acc@1": 84.228, + "acc@5": 96.878, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class EfficientNet_V2_M_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/efficientnet_v2_m-dc08266a.pth", + transforms=partial( + ImageClassification, + crop_size=480, + resize_size=480, + interpolation=InterpolationMode.BILINEAR, + ), + meta={ + **_COMMON_META_V2, + "num_params": 54139356, + "size": (480, 480), + "acc@1": 85.112, + "acc@5": 97.156, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class EfficientNet_V2_L_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/efficientnet_v2_l-59c71312.pth", + transforms=partial( + ImageClassification, + crop_size=480, + resize_size=480, + interpolation=InterpolationMode.BICUBIC, + mean=(0.5, 0.5, 0.5), + std=(0.5, 0.5, 0.5), + ), + meta={ + **_COMMON_META_V2, + "num_params": 118515272, + "size": (480, 480), + "acc@1": 85.808, + "acc@5": 97.788, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +@handle_legacy_interface(weights=("pretrained", EfficientNet_B0_Weights.IMAGENET1K_V1)) +def efficientnet_b0( + *, weights: Optional[EfficientNet_B0_Weights] = None, progress: bool = True, **kwargs: Any +) -> EfficientNet: """ Constructs a EfficientNet B0 architecture from `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (EfficientNet_B0_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - arch = "efficientnet_b0" - inverted_residual_setting, last_channel = _efficientnet_conf(arch, width_mult=1.0, depth_mult=1.0) - return _efficientnet(arch, inverted_residual_setting, 0.2, last_channel, pretrained, progress, **kwargs) + weights = EfficientNet_B0_Weights.verify(weights) + + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b0", width_mult=1.0, depth_mult=1.0) + return _efficientnet(inverted_residual_setting, 0.2, last_channel, weights, progress, **kwargs) -def efficientnet_b1(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: +@handle_legacy_interface(weights=("pretrained", EfficientNet_B1_Weights.IMAGENET1K_V1)) +def efficientnet_b1( + *, weights: Optional[EfficientNet_B1_Weights] = None, progress: bool = True, **kwargs: Any +) -> EfficientNet: """ Constructs a EfficientNet B1 architecture from `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_. 
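For illustration only (not part of the patch), a sketch of how the per-variant classification weight enums above are meant to be consumed; it assumes a build with this PR and that each Weights entry exposes the .transforms preset and .meta dict defined earlier in the diff:

from torchvision.models import efficientnet_b1, EfficientNet_B1_Weights

# Two checkpoints are registered for B1; DEFAULT points at the newer IMAGENET1K_V2 recipe.
w_v1 = EfficientNet_B1_Weights.IMAGENET1K_V1
w_v2 = EfficientNet_B1_Weights.DEFAULT  # == IMAGENET1K_V2

print(w_v1.meta["acc@1"], w_v2.meta["acc@1"])  # 78.642 vs 79.838, as recorded in meta

model = efficientnet_b1(weights=w_v2)
preprocess = w_v2.transforms()  # ImageClassification preset: crop 240 / resize 255, bilinear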
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (EfficientNet_B1_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - arch = "efficientnet_b1" - inverted_residual_setting, last_channel = _efficientnet_conf(arch, width_mult=1.0, depth_mult=1.1) - return _efficientnet(arch, inverted_residual_setting, 0.2, last_channel, pretrained, progress, **kwargs) + weights = EfficientNet_B1_Weights.verify(weights) + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b1", width_mult=1.0, depth_mult=1.1) + return _efficientnet(inverted_residual_setting, 0.2, last_channel, weights, progress, **kwargs) -def efficientnet_b2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: + +@handle_legacy_interface(weights=("pretrained", EfficientNet_B2_Weights.IMAGENET1K_V1)) +def efficientnet_b2( + *, weights: Optional[EfficientNet_B2_Weights] = None, progress: bool = True, **kwargs: Any +) -> EfficientNet: """ Constructs a EfficientNet B2 architecture from `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (EfficientNet_B2_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - arch = "efficientnet_b2" - inverted_residual_setting, last_channel = _efficientnet_conf(arch, width_mult=1.1, depth_mult=1.2) - return _efficientnet(arch, inverted_residual_setting, 0.3, last_channel, pretrained, progress, **kwargs) + weights = EfficientNet_B2_Weights.verify(weights) + + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b2", width_mult=1.1, depth_mult=1.2) + return _efficientnet(inverted_residual_setting, 0.3, last_channel, weights, progress, **kwargs) -def efficientnet_b3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: +@handle_legacy_interface(weights=("pretrained", EfficientNet_B3_Weights.IMAGENET1K_V1)) +def efficientnet_b3( + *, weights: Optional[EfficientNet_B3_Weights] = None, progress: bool = True, **kwargs: Any +) -> EfficientNet: """ Constructs a EfficientNet B3 architecture from `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (EfficientNet_B3_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - arch = "efficientnet_b3" - inverted_residual_setting, last_channel = _efficientnet_conf(arch, width_mult=1.2, depth_mult=1.4) - return _efficientnet(arch, inverted_residual_setting, 0.3, last_channel, pretrained, progress, **kwargs) + weights = EfficientNet_B3_Weights.verify(weights) + + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b3", width_mult=1.2, depth_mult=1.4) + return _efficientnet(inverted_residual_setting, 0.3, last_channel, weights, progress, **kwargs) -def efficientnet_b4(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: +@handle_legacy_interface(weights=("pretrained", EfficientNet_B4_Weights.IMAGENET1K_V1)) +def efficientnet_b4( + *, weights: Optional[EfficientNet_B4_Weights] = None, progress: bool = True, **kwargs: Any +) -> EfficientNet: """ Constructs a EfficientNet B4 architecture from `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (EfficientNet_B4_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - arch = "efficientnet_b4" - inverted_residual_setting, last_channel = _efficientnet_conf(arch, width_mult=1.4, depth_mult=1.8) - return _efficientnet(arch, inverted_residual_setting, 0.4, last_channel, pretrained, progress, **kwargs) + weights = EfficientNet_B4_Weights.verify(weights) + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b4", width_mult=1.4, depth_mult=1.8) + return _efficientnet(inverted_residual_setting, 0.4, last_channel, weights, progress, **kwargs) -def efficientnet_b5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: + +@handle_legacy_interface(weights=("pretrained", EfficientNet_B5_Weights.IMAGENET1K_V1)) +def efficientnet_b5( + *, weights: Optional[EfficientNet_B5_Weights] = None, progress: bool = True, **kwargs: Any +) -> EfficientNet: """ Constructs a EfficientNet B5 architecture from `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (EfficientNet_B5_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - arch = "efficientnet_b5" - inverted_residual_setting, last_channel = _efficientnet_conf(arch, width_mult=1.6, depth_mult=2.2) + weights = EfficientNet_B5_Weights.verify(weights) + + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b5", width_mult=1.6, depth_mult=2.2) return _efficientnet( - arch, inverted_residual_setting, 0.4, last_channel, - pretrained, + weights, progress, norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01), **kwargs, ) -def efficientnet_b6(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: +@handle_legacy_interface(weights=("pretrained", EfficientNet_B6_Weights.IMAGENET1K_V1)) +def efficientnet_b6( + *, weights: Optional[EfficientNet_B6_Weights] = None, progress: bool = True, **kwargs: Any +) -> EfficientNet: """ Constructs a EfficientNet B6 architecture from `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (EfficientNet_B6_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - arch = "efficientnet_b6" - inverted_residual_setting, last_channel = _efficientnet_conf(arch, width_mult=1.8, depth_mult=2.6) + weights = EfficientNet_B6_Weights.verify(weights) + + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b6", width_mult=1.8, depth_mult=2.6) return _efficientnet( - arch, inverted_residual_setting, 0.5, last_channel, - pretrained, + weights, progress, norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01), **kwargs, ) -def efficientnet_b7(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: +@handle_legacy_interface(weights=("pretrained", EfficientNet_B7_Weights.IMAGENET1K_V1)) +def efficientnet_b7( + *, weights: Optional[EfficientNet_B7_Weights] = None, progress: bool = True, **kwargs: Any +) -> EfficientNet: """ Constructs a EfficientNet B7 architecture from `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (EfficientNet_B7_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - arch = "efficientnet_b7" - inverted_residual_setting, last_channel = _efficientnet_conf(arch, width_mult=2.0, depth_mult=3.1) + weights = EfficientNet_B7_Weights.verify(weights) + + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b7", width_mult=2.0, depth_mult=3.1) return _efficientnet( - arch, inverted_residual_setting, 0.5, last_channel, - pretrained, + weights, progress, norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01), **kwargs, ) -def efficientnet_v2_s(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: +@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_S_Weights.IMAGENET1K_V1)) +def efficientnet_v2_s( + *, weights: Optional[EfficientNet_V2_S_Weights] = None, progress: bool = True, **kwargs: Any +) -> EfficientNet: """ Constructs an EfficientNetV2-S architecture from `"EfficientNetV2: Smaller Models and Faster Training" `_. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (EfficientNet_V2_S_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - arch = "efficientnet_v2_s" - inverted_residual_setting, last_channel = _efficientnet_conf(arch) + weights = EfficientNet_V2_S_Weights.verify(weights) + + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_s") return _efficientnet( - arch, inverted_residual_setting, 0.2, last_channel, - pretrained, + weights, progress, norm_layer=partial(nn.BatchNorm2d, eps=1e-03), **kwargs, ) -def efficientnet_v2_m(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: +@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_M_Weights.IMAGENET1K_V1)) +def efficientnet_v2_m( + *, weights: Optional[EfficientNet_V2_M_Weights] = None, progress: bool = True, **kwargs: Any +) -> EfficientNet: """ Constructs an EfficientNetV2-M architecture from `"EfficientNetV2: Smaller Models and Faster Training" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (EfficientNet_V2_M_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - arch = "efficientnet_v2_m" - inverted_residual_setting, last_channel = _efficientnet_conf(arch) + weights = EfficientNet_V2_M_Weights.verify(weights) + + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_m") return _efficientnet( - arch, inverted_residual_setting, 0.3, last_channel, - pretrained, + weights, progress, norm_layer=partial(nn.BatchNorm2d, eps=1e-03), **kwargs, ) -def efficientnet_v2_l(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: +@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_L_Weights.IMAGENET1K_V1)) +def efficientnet_v2_l( + *, weights: Optional[EfficientNet_V2_L_Weights] = None, progress: bool = True, **kwargs: Any +) -> EfficientNet: """ Constructs an EfficientNetV2-L architecture from `"EfficientNetV2: Smaller Models and Faster Training" `_. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (EfficientNet_V2_L_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - arch = "efficientnet_v2_l" - inverted_residual_setting, last_channel = _efficientnet_conf(arch) + weights = EfficientNet_V2_L_Weights.verify(weights) + + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_l") return _efficientnet( - arch, inverted_residual_setting, 0.4, last_channel, - pretrained, + weights, progress, norm_layer=partial(nn.BatchNorm2d, eps=1e-03), **kwargs, diff --git a/torchvision/models/googlenet.py b/torchvision/models/googlenet.py index f3487b44c09..ced92571974 100644 --- a/torchvision/models/googlenet.py +++ b/torchvision/models/googlenet.py @@ -1,5 +1,6 @@ import warnings from collections import namedtuple +from functools import partial from typing import Optional, Tuple, List, Callable, Any import torch @@ -7,15 +8,15 @@ import torch.nn.functional as F from torch import Tensor -from .._internally_replaced_utils import load_state_dict_from_url +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once +from ._api import WeightsEnum, Weights +from ._meta import _IMAGENET_CATEGORIES +from ._utils import handle_legacy_interface, _ovewrite_named_param -__all__ = ["GoogLeNet", "googlenet", "GoogLeNetOutputs", "_GoogLeNetOutputs"] -model_urls = { - # GoogLeNet ported from TensorFlow - "googlenet": "https://download.pytorch.org/models/googlenet-1378be20.pth", -} +__all__ = ["GoogLeNet", "GoogLeNetOutputs", "_GoogLeNetOutputs", "GoogLeNet_Weights", "googlenet"] + GoogLeNetOutputs = namedtuple("GoogLeNetOutputs", ["logits", "aux_logits2", "aux_logits1"]) GoogLeNetOutputs.__annotations__ = {"logits": Tensor, "aux_logits2": Optional[Tensor], "aux_logits1": Optional[Tensor]} @@ -274,38 +275,62 @@ def forward(self, x: Tensor) -> Tensor: return F.relu(x, inplace=True) -def googlenet(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> GoogLeNet: +class GoogLeNet_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/googlenet-1378be20.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + "task": "image_classification", + "architecture": "GoogLeNet", + "publication_year": 2014, + "num_params": 6624904, + "size": (224, 224), + "min_size": (15, 15), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#googlenet", + "acc@1": 69.778, + "acc@5": 89.530, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +@handle_legacy_interface(weights=("pretrained", GoogLeNet_Weights.IMAGENET1K_V1)) +def googlenet(*, weights: Optional[GoogLeNet_Weights] = None, progress: bool = True, **kwargs: Any) -> GoogLeNet: r"""GoogLeNet (Inception v1) model architecture from `"Going Deeper with Convolutions" `_. The required minimum input size of the model is 15x15. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (GoogLeNet_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr aux_logits (bool): If True, adds two auxiliary branches that can improve training. 
Default: *False* when pretrained is True otherwise *True* transform_input (bool): If True, preprocesses the input according to the method with which it - was trained on ImageNet. Default: True if ``pretrained=True``, else False. + was trained on ImageNet. Default: True if ``weights=GoogLeNet_Weights.IMAGENET1K_V1``, else False. """ - if pretrained: + weights = GoogLeNet_Weights.verify(weights) + + original_aux_logits = kwargs.get("aux_logits", False) + if weights is not None: if "transform_input" not in kwargs: - kwargs["transform_input"] = True - if "aux_logits" not in kwargs: - kwargs["aux_logits"] = False - if kwargs["aux_logits"]: - warnings.warn( - "auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them" - ) - original_aux_logits = kwargs["aux_logits"] - kwargs["aux_logits"] = True - kwargs["init_weights"] = False - model = GoogLeNet(**kwargs) - state_dict = load_state_dict_from_url(model_urls["googlenet"], progress=progress) - model.load_state_dict(state_dict) + _ovewrite_named_param(kwargs, "transform_input", True) + _ovewrite_named_param(kwargs, "aux_logits", True) + _ovewrite_named_param(kwargs, "init_weights", False) + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + + model = GoogLeNet(**kwargs) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) if not original_aux_logits: model.aux_logits = False model.aux1 = None # type: ignore[assignment] model.aux2 = None # type: ignore[assignment] - return model + else: + warnings.warn( + "auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them" + ) - return GoogLeNet(**kwargs) + return model diff --git a/torchvision/models/inception.py b/torchvision/models/inception.py index 0fe6400a681..816fab45549 100644 --- a/torchvision/models/inception.py +++ b/torchvision/models/inception.py @@ -1,23 +1,22 @@ import warnings from collections import namedtuple +from functools import partial from typing import Callable, Any, Optional, Tuple, List import torch import torch.nn.functional as F from torch import nn, Tensor -from .._internally_replaced_utils import load_state_dict_from_url +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once +from ._api import WeightsEnum, Weights +from ._meta import _IMAGENET_CATEGORIES +from ._utils import handle_legacy_interface, _ovewrite_named_param -__all__ = ["Inception3", "inception_v3", "InceptionOutputs", "_InceptionOutputs"] +__all__ = ["Inception3", "InceptionOutputs", "_InceptionOutputs", "Inception_V3_Weights", "inception_v3"] -model_urls = { - # Inception v3 ported from TensorFlow - "inception_v3_google": "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth", -} - InceptionOutputs = namedtuple("InceptionOutputs", ["logits", "aux_logits"]) InceptionOutputs.__annotations__ = {"logits": Tensor, "aux_logits": Optional[Tensor]} @@ -408,7 +407,29 @@ def forward(self, x: Tensor) -> Tensor: return F.relu(x, inplace=True) -def inception_v3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> Inception3: +class Inception_V3_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth", + transforms=partial(ImageClassification, crop_size=299, resize_size=342), + meta={ + "task": "image_classification", + "architecture": "InceptionV3", + "publication_year": 2015, + "num_params": 27161264, + "size": (299, 299), + 
"min_size": (75, 75), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#inception-v3", + "acc@1": 77.294, + "acc@5": 93.450, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +@handle_legacy_interface(weights=("pretrained", Inception_V3_Weights.IMAGENET1K_V1)) +def inception_v3(*, weights: Optional[Inception_V3_Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3: r"""Inception v3 model architecture from `"Rethinking the Inception Architecture for Computer Vision" `_. The required minimum input size of the model is 75x75. @@ -418,28 +439,29 @@ def inception_v3(pretrained: bool = False, progress: bool = True, **kwargs: Any) N x 3 x 299 x 299, so ensure your images are sized accordingly. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (Inception_V3_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr aux_logits (bool): If True, add an auxiliary branch that can improve training. Default: *True* transform_input (bool): If True, preprocesses the input according to the method with which it - was trained on ImageNet. Default: True if ``pretrained=True``, else False. + was trained on ImageNet. Default: True if ``weights=Inception_V3_Weights.IMAGENET1K_V1``, else False. """ - if pretrained: + weights = Inception_V3_Weights.verify(weights) + + original_aux_logits = kwargs.get("aux_logits", True) + if weights is not None: if "transform_input" not in kwargs: - kwargs["transform_input"] = True - if "aux_logits" in kwargs: - original_aux_logits = kwargs["aux_logits"] - kwargs["aux_logits"] = True - else: - original_aux_logits = True - kwargs["init_weights"] = False # we are loading weights from a pretrained model - model = Inception3(**kwargs) - state_dict = load_state_dict_from_url(model_urls["inception_v3_google"], progress=progress) - model.load_state_dict(state_dict) + _ovewrite_named_param(kwargs, "transform_input", True) + _ovewrite_named_param(kwargs, "aux_logits", True) + _ovewrite_named_param(kwargs, "init_weights", False) + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + + model = Inception3(**kwargs) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) if not original_aux_logits: model.aux_logits = False model.AuxLogits = None - return model - return Inception3(**kwargs) + return model diff --git a/torchvision/models/mnasnet.py b/torchvision/models/mnasnet.py index 9608c555a88..578e77f7934 100644 --- a/torchvision/models/mnasnet.py +++ b/torchvision/models/mnasnet.py @@ -1,21 +1,30 @@ import warnings -from typing import Any, Dict, List +from functools import partial +from typing import Any, Dict, List, Optional import torch import torch.nn as nn from torch import Tensor -from .._internally_replaced_utils import load_state_dict_from_url +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once +from ._api import WeightsEnum, Weights +from ._meta import _IMAGENET_CATEGORIES +from ._utils import handle_legacy_interface, _ovewrite_named_param -__all__ = ["MNASNet", "mnasnet0_5", "mnasnet0_75", "mnasnet1_0", "mnasnet1_3"] -_MODEL_URLS = { - "mnasnet0_5": "https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth", - "mnasnet0_75": None, - "mnasnet1_0": 
"https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth", - "mnasnet1_3": None, -} +__all__ = [ + "MNASNet", + "MNASNet0_5_Weights", + "MNASNet0_75_Weights", + "MNASNet1_0_Weights", + "MNASNet1_3_Weights", + "mnasnet0_5", + "mnasnet0_75", + "mnasnet1_0", + "mnasnet1_3", +] + # Paper suggests 0.9997 momentum, for TensorFlow. Equivalent PyTorch momentum is # 1.0 - tensorflow. @@ -202,68 +211,123 @@ def _load_from_state_dict( ) -def _load_pretrained(model_name: str, model: nn.Module, progress: bool) -> None: - if model_name not in _MODEL_URLS or _MODEL_URLS[model_name] is None: - raise ValueError(f"No checkpoint is available for model type {model_name}") - checkpoint_url = _MODEL_URLS[model_name] - model.load_state_dict(load_state_dict_from_url(checkpoint_url, progress=progress)) +_COMMON_META = { + "task": "image_classification", + "architecture": "MNASNet", + "publication_year": 2018, + "size": (224, 224), + "min_size": (1, 1), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "recipe": "https://github.com/1e100/mnasnet_trainer", +} + + +class MNASNet0_5_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 2218512, + "acc@1": 67.734, + "acc@5": 87.490, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class MNASNet0_75_Weights(WeightsEnum): + # If a default model is added here the corresponding changes need to be done in mnasnet0_75 + pass + + +class MNASNet1_0_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 4383312, + "acc@1": 73.456, + "acc@5": 91.510, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class MNASNet1_3_Weights(WeightsEnum): + # If a default model is added here the corresponding changes need to be done in mnasnet1_3 + pass + + +def _mnasnet(alpha: float, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any) -> MNASNet: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + model = MNASNet(alpha, **kwargs) -def mnasnet0_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet: + if weights: + model.load_state_dict(weights.get_state_dict(progress=progress)) + + return model + + +@handle_legacy_interface(weights=("pretrained", MNASNet0_5_Weights.IMAGENET1K_V1)) +def mnasnet0_5(*, weights: Optional[MNASNet0_5_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: r"""MNASNet with depth multiplier of 0.5 from `"MnasNet: Platform-Aware Neural Architecture Search for Mobile" `_. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (MNASNet0_5_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - model = MNASNet(0.5, **kwargs) - if pretrained: - _load_pretrained("mnasnet0_5", model, progress) - return model + weights = MNASNet0_5_Weights.verify(weights) + + return _mnasnet(0.5, weights, progress, **kwargs) -def mnasnet0_75(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet: +@handle_legacy_interface(weights=("pretrained", None)) +def mnasnet0_75(*, weights: Optional[MNASNet0_75_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: r"""MNASNet with depth multiplier of 0.75 from `"MnasNet: Platform-Aware Neural Architecture Search for Mobile" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (MNASNet0_75_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - model = MNASNet(0.75, **kwargs) - if pretrained: - _load_pretrained("mnasnet0_75", model, progress) - return model + weights = MNASNet0_75_Weights.verify(weights) + + return _mnasnet(0.75, weights, progress, **kwargs) -def mnasnet1_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet: +@handle_legacy_interface(weights=("pretrained", MNASNet1_0_Weights.IMAGENET1K_V1)) +def mnasnet1_0(*, weights: Optional[MNASNet1_0_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: r"""MNASNet with depth multiplier of 1.0 from `"MnasNet: Platform-Aware Neural Architecture Search for Mobile" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (MNASNet1_0_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - model = MNASNet(1.0, **kwargs) - if pretrained: - _load_pretrained("mnasnet1_0", model, progress) - return model + weights = MNASNet1_0_Weights.verify(weights) + return _mnasnet(1.0, weights, progress, **kwargs) -def mnasnet1_3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet: + +@handle_legacy_interface(weights=("pretrained", None)) +def mnasnet1_3(*, weights: Optional[MNASNet1_3_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: r"""MNASNet with depth multiplier of 1.3 from `"MnasNet: Platform-Aware Neural Architecture Search for Mobile" `_. 
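One more usage note (not part of the patch): the MNASNet builders now share the _mnasnet helper, and variants without a released checkpoint keep an empty weights enum, so only the 0.5 and 1.0 depth multipliers can be constructed with pretrained weights. A hedged sketch:

from torchvision.models import mnasnet0_5, mnasnet0_75, MNASNet0_5_Weights

# The 0.5 variant has an IMAGENET1K_V1 checkpoint registered above.
pretrained = mnasnet0_5(weights=MNASNet0_5_Weights.DEFAULT)

# MNASNet0_75_Weights defines no entries yet, so the 0.75 variant is built from scratch.
scratch = mnasnet0_75(weights=None)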
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (MNASNet1_3_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - model = MNASNet(1.3, **kwargs) - if pretrained: - _load_pretrained("mnasnet1_3", model, progress) - return model + weights = MNASNet1_3_Weights.verify(weights) + + return _mnasnet(1.3, weights, progress, **kwargs) diff --git a/torchvision/models/mobilenet.py b/torchvision/models/mobilenet.py index 4108305d3f5..0a270d14d3a 100644 --- a/torchvision/models/mobilenet.py +++ b/torchvision/models/mobilenet.py @@ -1,4 +1,6 @@ -from .mobilenetv2 import MobileNetV2, mobilenet_v2, __all__ as mv2_all -from .mobilenetv3 import MobileNetV3, mobilenet_v3_large, mobilenet_v3_small, __all__ as mv3_all +from .mobilenetv2 import * # noqa: F401, F403 +from .mobilenetv3 import * # noqa: F401, F403 +from .mobilenetv2 import __all__ as mv2_all +from .mobilenetv3 import __all__ as mv3_all __all__ = mv2_all + mv3_all diff --git a/torchvision/models/mobilenetv2.py b/torchvision/models/mobilenetv2.py index f65993b0a5a..085049117ec 100644 --- a/torchvision/models/mobilenetv2.py +++ b/torchvision/models/mobilenetv2.py @@ -1,22 +1,20 @@ import warnings +from functools import partial from typing import Callable, Any, Optional, List import torch from torch import Tensor from torch import nn -from .._internally_replaced_utils import load_state_dict_from_url from ..ops.misc import Conv2dNormActivation +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once -from ._utils import _make_divisible +from ._api import WeightsEnum, Weights +from ._meta import _IMAGENET_CATEGORIES +from ._utils import handle_legacy_interface, _ovewrite_named_param, _make_divisible -__all__ = ["MobileNetV2", "mobilenet_v2"] - - -model_urls = { - "mobilenet_v2": "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth", -} +__all__ = ["MobileNetV2", "MobileNet_V2_Weights", "mobilenet_v2"] # necessary for backwards compatibility @@ -196,17 +194,62 @@ def forward(self, x: Tensor) -> Tensor: return self._forward_impl(x) -def mobilenet_v2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV2: +_COMMON_META = { + "task": "image_classification", + "architecture": "MobileNetV2", + "publication_year": 2018, + "num_params": 3504872, + "size": (224, 224), + "min_size": (1, 1), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, +} + + +class MobileNet_V2_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/mobilenet_v2-b0353104.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv2", + "acc@1": 71.878, + "acc@5": 90.286, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/mobilenet_v2-7ebf99e0.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-reg-tuning", + "acc@1": 72.154, + "acc@5": 90.822, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +@handle_legacy_interface(weights=("pretrained", MobileNet_V2_Weights.IMAGENET1K_V1)) +def mobilenet_v2( + *, weights: Optional[MobileNet_V2_Weights] = None, progress: bool = True, **kwargs: Any +) -> MobileNetV2: """ 
Constructs a MobileNetV2 architecture from `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (MobileNet_V2_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = MobileNet_V2_Weights.verify(weights) + + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + model = MobileNetV2(**kwargs) - if pretrained: - state_dict = load_state_dict_from_url(model_urls["mobilenet_v2"], progress=progress) - model.load_state_dict(state_dict) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model diff --git a/torchvision/models/mobilenetv3.py b/torchvision/models/mobilenetv3.py index 530467d6d53..91e1ea91a94 100644 --- a/torchvision/models/mobilenetv3.py +++ b/torchvision/models/mobilenetv3.py @@ -5,19 +5,21 @@ import torch from torch import nn, Tensor -from .._internally_replaced_utils import load_state_dict_from_url from ..ops.misc import Conv2dNormActivation, SqueezeExcitation as SElayer +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once -from ._utils import _make_divisible +from ._api import WeightsEnum, Weights +from ._meta import _IMAGENET_CATEGORIES +from ._utils import handle_legacy_interface, _ovewrite_named_param, _make_divisible -__all__ = ["MobileNetV3", "mobilenet_v3_large", "mobilenet_v3_small"] - - -model_urls = { - "mobilenet_v3_large": "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth", - "mobilenet_v3_small": "https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth", -} +__all__ = [ + "MobileNetV3", + "MobileNet_V3_Large_Weights", + "MobileNet_V3_Small_Weights", + "mobilenet_v3_large", + "mobilenet_v3_small", +] class SqueezeExcitation(SElayer): @@ -284,45 +286,106 @@ def _mobilenet_v3_conf( def _mobilenet_v3( - arch: str, inverted_residual_setting: List[InvertedResidualConfig], last_channel: int, - pretrained: bool, + weights: Optional[WeightsEnum], progress: bool, **kwargs: Any, -): +) -> MobileNetV3: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs) - if pretrained: - if model_urls.get(arch, None) is None: - raise ValueError(f"No checkpoint is available for model type {arch}") - state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) - model.load_state_dict(state_dict) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model -def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3: +_COMMON_META = { + "task": "image_classification", + "architecture": "MobileNetV3", + "publication_year": 2019, + "size": (224, 224), + "min_size": (1, 1), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, +} + + +class MobileNet_V3_Large_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 5483032, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv3-large--small", + "acc@1": 74.042, + "acc@5": 91.340, + }, + ) + 
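For the MobileNetV2 hunk above, a short sketch of the caller-side behaviour (a sketch under two assumptions: the enum is importable from torchvision.models.mobilenetv2, and the MobileNetV2 classifier head remains a Dropout plus Linear pair):

    from torchvision.models.mobilenetv2 import mobilenet_v2, MobileNet_V2_Weights

    weights = MobileNet_V2_Weights.DEFAULT          # alias for IMAGENET1K_V2
    model = mobilenet_v2(weights=weights)

    # num_classes is derived from the checkpoint metadata, so the head always
    # matches the downloaded state dict (the ImageNet categories here).
    assert model.classifier[-1].out_features == len(weights.meta["categories"])
    print(weights.meta["acc@1"])                    # 72.154 for the V2 recipe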
IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/mobilenet_v3_large-5c1a4163.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 5483032, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-reg-tuning", + "acc@1": 75.274, + "acc@5": 92.566, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class MobileNet_V3_Small_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 2542856, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv3-large--small", + "acc@1": 67.668, + "acc@5": 87.402, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +@handle_legacy_interface(weights=("pretrained", MobileNet_V3_Large_Weights.IMAGENET1K_V1)) +def mobilenet_v3_large( + *, weights: Optional[MobileNet_V3_Large_Weights] = None, progress: bool = True, **kwargs: Any +) -> MobileNetV3: """ Constructs a large MobileNetV3 architecture from `"Searching for MobileNetV3" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (MobileNet_V3_Large_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - arch = "mobilenet_v3_large" - inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, **kwargs) - return _mobilenet_v3(arch, inverted_residual_setting, last_channel, pretrained, progress, **kwargs) + weights = MobileNet_V3_Large_Weights.verify(weights) + + inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large", **kwargs) + return _mobilenet_v3(inverted_residual_setting, last_channel, weights, progress, **kwargs) -def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3: +@handle_legacy_interface(weights=("pretrained", MobileNet_V3_Small_Weights.IMAGENET1K_V1)) +def mobilenet_v3_small( + *, weights: Optional[MobileNet_V3_Small_Weights] = None, progress: bool = True, **kwargs: Any +) -> MobileNetV3: """ Constructs a small MobileNetV3 architecture from `"Searching for MobileNetV3" `_. 
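The MobileNetV3 builders follow the same shape; the sketch below additionally assumes that a Weights entry's preset is instantiated by calling weights.transforms(), which is how the ImageClassification presets referenced above are meant to be consumed:

    from torchvision.models.mobilenetv3 import mobilenet_v3_large, MobileNet_V3_Large_Weights

    weights = MobileNet_V3_Large_Weights.IMAGENET1K_V2   # the new DEFAULT
    model = mobilenet_v3_large(weights=weights).eval()

    # The V2 preset resizes to 232 before the 224 crop, unlike the V1 preset.
    preprocess = weights.transforms()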
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (MobileNet_V3_Small_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - arch = "mobilenet_v3_small" - inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, **kwargs) - return _mobilenet_v3(arch, inverted_residual_setting, last_channel, pretrained, progress, **kwargs) + weights = MobileNet_V3_Small_Weights.verify(weights) + + inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_small", **kwargs) + return _mobilenet_v3(inverted_residual_setting, last_channel, weights, progress, **kwargs) diff --git a/torchvision/models/optical_flow/__init__.py b/torchvision/models/optical_flow/__init__.py index 9dd32f25dec..89d2302f825 100644 --- a/torchvision/models/optical_flow/__init__.py +++ b/torchvision/models/optical_flow/__init__.py @@ -1 +1 @@ -from .raft import RAFT, raft_large, raft_small +from .raft import * diff --git a/torchvision/models/optical_flow/raft.py b/torchvision/models/optical_flow/raft.py index 00200529f66..244d2b2fac1 100644 --- a/torchvision/models/optical_flow/raft.py +++ b/torchvision/models/optical_flow/raft.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional import torch import torch.nn as nn @@ -8,8 +8,10 @@ from torch.nn.modules.instancenorm import InstanceNorm2d from torchvision.ops import Conv2dNormActivation -from ..._internally_replaced_utils import load_state_dict_from_url +from ...transforms._presets import OpticalFlow, InterpolationMode from ...utils import _log_api_usage_once +from .._api import Weights, WeightsEnum +from .._utils import handle_legacy_interface from ._utils import grid_sample, make_coords_grid, upsample_flow @@ -17,15 +19,11 @@ "RAFT", "raft_large", "raft_small", + "Raft_Large_Weights", + "Raft_Small_Weights", ) -_MODELS_URLS = { - "raft_large": "https://download.pytorch.org/models/raft_large_C_T_SKHT_V2-ff5fadd5.pth", - "raft_small": "https://download.pytorch.org/models/raft_small_C_T_V2-01064c6d.pth", -} - - class ResidualBlock(nn.Module): """Slightly modified Residual block with extra relu and biases.""" @@ -513,10 +511,139 @@ def forward(self, image1, image2, num_flow_updates: int = 12): return flow_predictions +_COMMON_META = { + "task": "optical_flow", + "architecture": "RAFT", + "publication_year": 2020, + "interpolation": InterpolationMode.BILINEAR, +} + + +class Raft_Large_Weights(WeightsEnum): + C_T_V1 = Weights( + # Chairs + Things, ported from original paper repo (raft-things.pth) + url="https://download.pytorch.org/models/raft_large_C_T_V1-22a6c225.pth", + transforms=OpticalFlow, + meta={ + **_COMMON_META, + "num_params": 5257536, + "recipe": "https://github.com/princeton-vl/RAFT", + "sintel_train_cleanpass_epe": 1.4411, + "sintel_train_finalpass_epe": 2.7894, + "kitti_train_per_image_epe": 5.0172, + "kitti_train_f1-all": 17.4506, + }, + ) + + C_T_V2 = Weights( + # Chairs + Things + url="https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth", + transforms=OpticalFlow, + meta={ + **_COMMON_META, + "num_params": 5257536, + "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow", + "sintel_train_cleanpass_epe": 1.3822, + "sintel_train_finalpass_epe": 2.7161, + "kitti_train_per_image_epe": 4.5118, + "kitti_train_f1-all": 16.0679, + }, + ) + + C_T_SKHT_V1 = Weights( + # Chairs + Things + Sintel fine-tuning, ported from original paper repo 
(raft-sintel.pth) + url="https://download.pytorch.org/models/raft_large_C_T_SKHT_V1-0b8c9e55.pth", + transforms=OpticalFlow, + meta={ + **_COMMON_META, + "num_params": 5257536, + "recipe": "https://github.com/princeton-vl/RAFT", + "sintel_test_cleanpass_epe": 1.94, + "sintel_test_finalpass_epe": 3.18, + }, + ) + + C_T_SKHT_V2 = Weights( + # Chairs + Things + Sintel fine-tuning, i.e.: + # Chairs + Things + (Sintel + Kitti + HD1K + Things_clean) + # Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel + url="https://download.pytorch.org/models/raft_large_C_T_SKHT_V2-ff5fadd5.pth", + transforms=OpticalFlow, + meta={ + **_COMMON_META, + "num_params": 5257536, + "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow", + "sintel_test_cleanpass_epe": 1.819, + "sintel_test_finalpass_epe": 3.067, + }, + ) + + C_T_SKHT_K_V1 = Weights( + # Chairs + Things + Sintel fine-tuning + Kitti fine-tuning, ported from the original repo (sintel-kitti.pth) + url="https://download.pytorch.org/models/raft_large_C_T_SKHT_K_V1-4a6a5039.pth", + transforms=OpticalFlow, + meta={ + **_COMMON_META, + "num_params": 5257536, + "recipe": "https://github.com/princeton-vl/RAFT", + "kitti_test_f1-all": 5.10, + }, + ) + + C_T_SKHT_K_V2 = Weights( + # Chairs + Things + Sintel fine-tuning + Kitti fine-tuning i.e.: + # Chairs + Things + (Sintel + Kitti + HD1K + Things_clean) + Kitti + # Same as CT_SKHT with extra fine-tuning on Kitti + # Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel and then on Kitti + url="https://download.pytorch.org/models/raft_large_C_T_SKHT_K_V2-b5c70766.pth", + transforms=OpticalFlow, + meta={ + **_COMMON_META, + "num_params": 5257536, + "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow", + "kitti_test_f1-all": 5.19, + }, + ) + + DEFAULT = C_T_SKHT_V2 + + +class Raft_Small_Weights(WeightsEnum): + C_T_V1 = Weights( + # Chairs + Things, ported from original paper repo (raft-small.pth) + url="https://download.pytorch.org/models/raft_small_C_T_V1-ad48884c.pth", + transforms=OpticalFlow, + meta={ + **_COMMON_META, + "num_params": 990162, + "recipe": "https://github.com/princeton-vl/RAFT", + "sintel_train_cleanpass_epe": 2.1231, + "sintel_train_finalpass_epe": 3.2790, + "kitti_train_per_image_epe": 7.6557, + "kitti_train_f1-all": 25.2801, + }, + ) + C_T_V2 = Weights( + # Chairs + Things + url="https://download.pytorch.org/models/raft_small_C_T_V2-01064c6d.pth", + transforms=OpticalFlow, + meta={ + **_COMMON_META, + "num_params": 990162, + "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow", + "sintel_train_cleanpass_epe": 1.9901, + "sintel_train_finalpass_epe": 3.2831, + "kitti_train_per_image_epe": 7.5978, + "kitti_train_f1-all": 25.2369, + }, + ) + + DEFAULT = C_T_V2 + + def _raft( *, - arch=None, - pretrained=False, + weights=None, progress=False, # Feature encoder feature_encoder_layers, @@ -590,38 +717,34 @@ def _raft( mask_predictor=mask_predictor, **kwargs, # not really needed, all params should be consumed by now ) - if pretrained: - state_dict = load_state_dict_from_url(_MODELS_URLS[arch], progress=progress) - model.load_state_dict(state_dict) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) return model -def raft_large(*, pretrained=False, progress=True, **kwargs): +@handle_legacy_interface(weights=("pretrained", 
Raft_Large_Weights.C_T_SKHT_V2)) +def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, **kwargs) -> RAFT: """RAFT model from `RAFT: Recurrent All Pairs Field Transforms for Optical Flow `_. Please see the example below for a tutorial on how to use this model. Args: - pretrained (bool): Whether to use weights that have been pre-trained on - :class:`~torchvsion.datasets.FlyingChairs` + :class:`~torchvsion.datasets.FlyingThings3D` - with two fine-tuning steps: - - - one on :class:`~torchvsion.datasets.Sintel` + :class:`~torchvsion.datasets.FlyingThings3D` - - one on :class:`~torchvsion.datasets.KittiFlow`. - - This corresponds to the ``C+T+S/K`` strategy in the paper. - - progress (bool): If True, displays a progress bar of the download to stderr. + weights (Raft_Large_Weights, optional): The pretrained weights for the model + progress (bool): If True, displays a progress bar of the download to stderr + kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class + to override any default. Returns: - nn.Module: The model. + RAFT: The model. """ + weights = Raft_Large_Weights.verify(weights) + return _raft( - arch="raft_large", - pretrained=pretrained, + weights=weights, progress=progress, # Feature encoder feature_encoder_layers=(64, 64, 96, 128, 256), @@ -650,25 +773,27 @@ def raft_large(*, pretrained=False, progress=True, **kwargs): ) -def raft_small(*, pretrained=False, progress=True, **kwargs): +@handle_legacy_interface(weights=("pretrained", Raft_Small_Weights.C_T_V2)) +def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, **kwargs) -> RAFT: """RAFT "small" model from `RAFT: Recurrent All Pairs Field Transforms for Optical Flow `_. Please see the example below for a tutorial on how to use this model. Args: - pretrained (bool): Whether to use weights that have been pre-trained on - :class:`~torchvsion.datasets.FlyingChairs` + :class:`~torchvsion.datasets.FlyingThings3D`. + weights (Raft_Small_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr + kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class + to override any default. Returns: - nn.Module: The model. + RAFT: The model.
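A small sketch of how the RAFT builders and the weight enums defined above fit together (illustrative, not part of the patch; the names are the ones re-exported through optical_flow/__init__.py earlier in this diff):

    from torchvision.models.optical_flow import (
        raft_large,
        raft_small,
        Raft_Large_Weights,
        Raft_Small_Weights,
    )

    # DEFAULT points at C_T_SKHT_V2: Chairs + Things, then Sintel fine-tuning.
    weights = Raft_Large_Weights.DEFAULT
    model = raft_large(weights=weights, progress=False)
    print(weights.meta["sintel_test_cleanpass_epe"])   # 1.819 per the metadata above

    # The small variant keeps the same interface.
    tiny = raft_small(weights=Raft_Small_Weights.C_T_V2)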
""" + weights = Raft_Small_Weights.verify(weights) return _raft( - arch="raft_small", - pretrained=pretrained, + weights=weights, progress=progress, # Feature encoder feature_encoder_layers=(32, 32, 64, 96, 128), diff --git a/torchvision/models/quantization/__init__.py b/torchvision/models/quantization/__init__.py index deae997a219..da8bbba3567 100644 --- a/torchvision/models/quantization/__init__.py +++ b/torchvision/models/quantization/__init__.py @@ -1,5 +1,5 @@ -from .mobilenet import * -from .resnet import * from .googlenet import * from .inception import * +from .mobilenet import * +from .resnet import * from .shufflenetv2 import * diff --git a/torchvision/models/quantization/googlenet.py b/torchvision/models/quantization/googlenet.py index 98d9382214f..1794c834eea 100644 --- a/torchvision/models/quantization/googlenet.py +++ b/torchvision/models/quantization/googlenet.py @@ -1,22 +1,25 @@ import warnings -from typing import Any, Optional +from functools import partial +from typing import Any, Optional, Union import torch import torch.nn as nn from torch import Tensor from torch.nn import functional as F -from torchvision.models.googlenet import GoogLeNetOutputs, BasicConv2d, Inception, InceptionAux, GoogLeNet, model_urls -from ..._internally_replaced_utils import load_state_dict_from_url +from ...transforms._presets import ImageClassification, InterpolationMode +from .._api import WeightsEnum, Weights +from .._meta import _IMAGENET_CATEGORIES +from .._utils import handle_legacy_interface, _ovewrite_named_param +from ..googlenet import GoogLeNetOutputs, BasicConv2d, Inception, InceptionAux, GoogLeNet, GoogLeNet_Weights from .utils import _fuse_modules, _replace_relu, quantize_model -__all__ = ["QuantizableGoogLeNet", "googlenet"] - -quant_model_urls = { - # fp32 GoogLeNet ported from TensorFlow, with weights quantized in PyTorch - "googlenet_fbgemm": "https://download.pytorch.org/models/quantized/googlenet_fbgemm-c00238cf.pth", -} +__all__ = [ + "QuantizableGoogLeNet", + "GoogLeNet_QuantizedWeights", + "googlenet", +] class QuantizableBasicConv2d(BasicConv2d): @@ -103,8 +106,41 @@ def fuse_model(self, is_qat: Optional[bool] = None) -> None: m.fuse_model(is_qat) +class GoogLeNet_QuantizedWeights(WeightsEnum): + IMAGENET1K_FBGEMM_V1 = Weights( + url="https://download.pytorch.org/models/quantized/googlenet_fbgemm-c00238cf.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + "task": "image_classification", + "architecture": "GoogLeNet", + "publication_year": 2014, + "num_params": 6624904, + "size": (224, 224), + "min_size": (15, 15), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "backend": "fbgemm", + "quantization": "Post Training Quantization", + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models", + "unquantized": GoogLeNet_Weights.IMAGENET1K_V1, + "acc@1": 69.826, + "acc@5": 89.404, + }, + ) + DEFAULT = IMAGENET1K_FBGEMM_V1 + + +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: GoogLeNet_QuantizedWeights.IMAGENET1K_FBGEMM_V1 + if kwargs.get("quantize", False) + else GoogLeNet_Weights.IMAGENET1K_V1, + ) +) def googlenet( - pretrained: bool = False, + *, + weights: Optional[Union[GoogLeNet_QuantizedWeights, GoogLeNet_Weights]] = None, progress: bool = True, quantize: bool = False, **kwargs: Any, @@ -117,49 +153,38 @@ def googlenet( GPU inference is not yet supported Args: - pretrained (bool): If True, returns a model 
pre-trained on ImageNet + weights (GoogLeNet_QuantizedWeights or GoogLeNet_Weights, optional): The pretrained + weights for the model progress (bool): If True, displays a progress bar of the download to stderr quantize (bool): If True, return a quantized version of the model - aux_logits (bool): If True, adds two auxiliary branches that can improve training. - Default: *False* when pretrained is True otherwise *True* - transform_input (bool): If True, preprocesses the input according to the method with which it - was trained on ImageNet. Default: True if ``pretrained=True``, else False. """ - if pretrained: + weights = (GoogLeNet_QuantizedWeights if quantize else GoogLeNet_Weights).verify(weights) + + original_aux_logits = kwargs.get("aux_logits", False) + if weights is not None: if "transform_input" not in kwargs: - kwargs["transform_input"] = True - if "aux_logits" not in kwargs: - kwargs["aux_logits"] = False - if kwargs["aux_logits"]: - warnings.warn( - "auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them" - ) - original_aux_logits = kwargs["aux_logits"] - kwargs["aux_logits"] = True - kwargs["init_weights"] = False + _ovewrite_named_param(kwargs, "transform_input", True) + _ovewrite_named_param(kwargs, "aux_logits", True) + _ovewrite_named_param(kwargs, "init_weights", False) + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + if "backend" in weights.meta: + _ovewrite_named_param(kwargs, "backend", weights.meta["backend"]) + backend = kwargs.pop("backend", "fbgemm") model = QuantizableGoogLeNet(**kwargs) _replace_relu(model) - if quantize: - # TODO use pretrained as a string to specify the backend - backend = "fbgemm" quantize_model(model, backend) - else: - assert pretrained in [True, False] - - if pretrained: - if quantize: - model_url = quant_model_urls["googlenet_" + backend] - else: - model_url = model_urls["googlenet"] - - state_dict = load_state_dict_from_url(model_url, progress=progress) - - model.load_state_dict(state_dict) + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) if not original_aux_logits: model.aux_logits = False model.aux1 = None # type: ignore[assignment] model.aux2 = None # type: ignore[assignment] + else: + warnings.warn( + "auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them" + ) + return model diff --git a/torchvision/models/quantization/inception.py b/torchvision/models/quantization/inception.py index 27d021428b9..ff5c9a37365 100644 --- a/torchvision/models/quantization/inception.py +++ b/torchvision/models/quantization/inception.py @@ -1,29 +1,28 @@ import warnings -from typing import Any, List, Optional +from functools import partial +from typing import Any, List, Optional, Union import torch import torch.nn as nn import torch.nn.functional as F from torch import Tensor from torchvision.models import inception as inception_module -from torchvision.models.inception import InceptionOutputs +from torchvision.models.inception import InceptionOutputs, Inception_V3_Weights -from ..._internally_replaced_utils import load_state_dict_from_url +from ...transforms._presets import ImageClassification, InterpolationMode +from .._api import WeightsEnum, Weights +from .._meta import _IMAGENET_CATEGORIES +from .._utils import handle_legacy_interface, _ovewrite_named_param from .utils import _fuse_modules, _replace_relu, quantize_model __all__ = [ "QuantizableInception3", + "Inception_V3_QuantizedWeights", 
"inception_v3", ] -quant_model_urls = { - # fp32 weights ported from TensorFlow, quantized in PyTorch - "inception_v3_google_fbgemm": "https://download.pytorch.org/models/quantized/inception_v3_google_fbgemm-71447a44.pth" -} - - class QuantizableBasicConv2d(inception_module.BasicConv2d): def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @@ -173,8 +172,41 @@ def fuse_model(self, is_qat: Optional[bool] = None) -> None: m.fuse_model(is_qat) +class Inception_V3_QuantizedWeights(WeightsEnum): + IMAGENET1K_FBGEMM_V1 = Weights( + url="https://download.pytorch.org/models/quantized/inception_v3_google_fbgemm-71447a44.pth", + transforms=partial(ImageClassification, crop_size=299, resize_size=342), + meta={ + "task": "image_classification", + "architecture": "InceptionV3", + "publication_year": 2015, + "num_params": 27161264, + "size": (299, 299), + "min_size": (75, 75), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "backend": "fbgemm", + "quantization": "Post Training Quantization", + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models", + "unquantized": Inception_V3_Weights.IMAGENET1K_V1, + "acc@1": 77.176, + "acc@5": 93.354, + }, + ) + DEFAULT = IMAGENET1K_FBGEMM_V1 + + +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: Inception_V3_QuantizedWeights.IMAGENET1K_FBGEMM_V1 + if kwargs.get("quantize", False) + else Inception_V3_Weights.IMAGENET1K_V1, + ) +) def inception_v3( - pretrained: bool = False, + *, + weights: Optional[Union[Inception_V3_QuantizedWeights, Inception_V3_Weights]] = None, progress: bool = True, quantize: bool = False, **kwargs: Any, @@ -191,48 +223,35 @@ def inception_v3( GPU inference is not yet supported Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (Inception_V3_QuantizedWeights or Inception_V3_Weights, optional): The pretrained + weights for the model progress (bool): If True, displays a progress bar of the download to stderr quantize (bool): If True, return a quantized version of the model - aux_logits (bool): If True, add an auxiliary branch that can improve training. - Default: *True* - transform_input (bool): If True, preprocesses the input according to the method with which it - was trained on ImageNet. Default: True if ``pretrained=True``, else False. 
""" - if pretrained: + weights = (Inception_V3_QuantizedWeights if quantize else Inception_V3_Weights).verify(weights) + + original_aux_logits = kwargs.get("aux_logits", False) + if weights is not None: if "transform_input" not in kwargs: - kwargs["transform_input"] = True - if "aux_logits" in kwargs: - original_aux_logits = kwargs["aux_logits"] - kwargs["aux_logits"] = True - else: - original_aux_logits = False + _ovewrite_named_param(kwargs, "transform_input", True) + _ovewrite_named_param(kwargs, "aux_logits", True) + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + if "backend" in weights.meta: + _ovewrite_named_param(kwargs, "backend", weights.meta["backend"]) + backend = kwargs.pop("backend", "fbgemm") model = QuantizableInception3(**kwargs) _replace_relu(model) - if quantize: - # TODO use pretrained as a string to specify the backend - backend = "fbgemm" quantize_model(model, backend) - else: - assert pretrained in [True, False] - - if pretrained: - if quantize: - if not original_aux_logits: - model.aux_logits = False - model.AuxLogits = None - model_url = quant_model_urls["inception_v3_google_" + backend] - else: - model_url = inception_module.model_urls["inception_v3_google"] - - state_dict = load_state_dict_from_url(model_url, progress=progress) - model.load_state_dict(state_dict) + if weights is not None: + if quantize and not original_aux_logits: + model.aux_logits = False + model.AuxLogits = None + model.load_state_dict(weights.get_state_dict(progress=progress)) + if not quantize and not original_aux_logits: + model.aux_logits = False + model.AuxLogits = None - if not quantize: - if not original_aux_logits: - model.aux_logits = False - model.AuxLogits = None return model diff --git a/torchvision/models/quantization/mobilenet.py b/torchvision/models/quantization/mobilenet.py index 8f2c42db640..0a270d14d3a 100644 --- a/torchvision/models/quantization/mobilenet.py +++ b/torchvision/models/quantization/mobilenet.py @@ -1,4 +1,6 @@ -from .mobilenetv2 import QuantizableMobileNetV2, mobilenet_v2, __all__ as mv2_all -from .mobilenetv3 import QuantizableMobileNetV3, mobilenet_v3_large, __all__ as mv3_all +from .mobilenetv2 import * # noqa: F401, F403 +from .mobilenetv3 import * # noqa: F401, F403 +from .mobilenetv2 import __all__ as mv2_all +from .mobilenetv3 import __all__ as mv3_all __all__ = mv2_all + mv3_all diff --git a/torchvision/models/quantization/mobilenetv2.py b/torchvision/models/quantization/mobilenetv2.py index 8cd9f16d13e..d9554e0ba9f 100644 --- a/torchvision/models/quantization/mobilenetv2.py +++ b/torchvision/models/quantization/mobilenetv2.py @@ -1,20 +1,24 @@ -from typing import Any, Optional +from functools import partial +from typing import Any, Optional, Union from torch import Tensor from torch import nn from torch.ao.quantization import QuantStub, DeQuantStub -from torchvision.models.mobilenetv2 import InvertedResidual, MobileNetV2, model_urls +from torchvision.models.mobilenetv2 import InvertedResidual, MobileNetV2, MobileNet_V2_Weights -from ..._internally_replaced_utils import load_state_dict_from_url from ...ops.misc import Conv2dNormActivation +from ...transforms._presets import ImageClassification, InterpolationMode +from .._api import WeightsEnum, Weights +from .._meta import _IMAGENET_CATEGORIES +from .._utils import handle_legacy_interface, _ovewrite_named_param from .utils import _fuse_modules, _replace_relu, quantize_model -__all__ = ["QuantizableMobileNetV2", "mobilenet_v2"] - -quant_model_urls = { - 
"mobilenet_v2_qnnpack": "https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth" -} +__all__ = [ + "QuantizableMobileNetV2", + "MobileNet_V2_QuantizedWeights", + "mobilenet_v2", +] class QuantizableInvertedResidual(InvertedResidual): @@ -60,8 +64,41 @@ def fuse_model(self, is_qat: Optional[bool] = None) -> None: m.fuse_model(is_qat) +class MobileNet_V2_QuantizedWeights(WeightsEnum): + IMAGENET1K_QNNPACK_V1 = Weights( + url="https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + "task": "image_classification", + "architecture": "MobileNetV2", + "publication_year": 2018, + "num_params": 3504872, + "size": (224, 224), + "min_size": (1, 1), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "backend": "qnnpack", + "quantization": "Quantization Aware Training", + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#qat-mobilenetv2", + "unquantized": MobileNet_V2_Weights.IMAGENET1K_V1, + "acc@1": 71.658, + "acc@5": 90.150, + }, + ) + DEFAULT = IMAGENET1K_QNNPACK_V1 + + +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: MobileNet_V2_QuantizedWeights.IMAGENET1K_QNNPACK_V1 + if kwargs.get("quantize", False) + else MobileNet_V2_Weights.IMAGENET1K_V1, + ) +) def mobilenet_v2( - pretrained: bool = False, + *, + weights: Optional[Union[MobileNet_V2_QuantizedWeights, MobileNet_V2_Weights]] = None, progress: bool = True, quantize: bool = False, **kwargs: Any, @@ -76,27 +113,25 @@ def mobilenet_v2( GPU inference is not yet supported Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet. - progress (bool): If True, displays a progress bar of the download to stderr - quantize(bool): If True, returns a quantized model, else returns a float model + weights (GoogLeNet_QuantizedWeights or GoogLeNet_Weights, optional): The pretrained + weights for the model + progress (bool): If True, displays a progress bar of the download to stderr + quantize(bool): If True, returns a quantized model, else returns a float model """ + weights = (MobileNet_V2_QuantizedWeights if quantize else MobileNet_V2_Weights).verify(weights) + + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + if "backend" in weights.meta: + _ovewrite_named_param(kwargs, "backend", weights.meta["backend"]) + backend = kwargs.pop("backend", "qnnpack") + model = QuantizableMobileNetV2(block=QuantizableInvertedResidual, **kwargs) _replace_relu(model) - if quantize: - # TODO use pretrained as a string to specify the backend - backend = "qnnpack" quantize_model(model, backend) - else: - assert pretrained in [True, False] - - if pretrained: - if quantize: - model_url = quant_model_urls["mobilenet_v2_" + backend] - else: - model_url = model_urls["mobilenet_v2"] - state_dict = load_state_dict_from_url(model_url, progress=progress) + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) - model.load_state_dict(state_dict) return model diff --git a/torchvision/models/quantization/mobilenetv3.py b/torchvision/models/quantization/mobilenetv3.py index 4d7e2f7baad..88907ec210a 100644 --- a/torchvision/models/quantization/mobilenetv3.py +++ b/torchvision/models/quantization/mobilenetv3.py @@ -1,20 +1,30 @@ -from typing import Any, List, Optional +from functools import partial +from typing import Any, List, Optional, Union import torch 
from torch import nn, Tensor from torch.ao.quantization import QuantStub, DeQuantStub -from ..._internally_replaced_utils import load_state_dict_from_url from ...ops.misc import Conv2dNormActivation, SqueezeExcitation -from ..mobilenetv3 import InvertedResidual, InvertedResidualConfig, MobileNetV3, model_urls, _mobilenet_v3_conf +from ...transforms._presets import ImageClassification, InterpolationMode +from .._api import WeightsEnum, Weights +from .._meta import _IMAGENET_CATEGORIES +from .._utils import handle_legacy_interface, _ovewrite_named_param +from ..mobilenetv3 import ( + InvertedResidual, + InvertedResidualConfig, + MobileNetV3, + _mobilenet_v3_conf, + MobileNet_V3_Large_Weights, +) from .utils import _fuse_modules, _replace_relu -__all__ = ["QuantizableMobileNetV3", "mobilenet_v3_large"] - -quant_model_urls = { - "mobilenet_v3_large_qnnpack": "https://download.pytorch.org/models/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth", -} +__all__ = [ + "QuantizableMobileNetV3", + "MobileNet_V3_Large_QuantizedWeights", + "mobilenet_v3_large", +] class QuantizableSqueezeExcitation(SqueezeExcitation): @@ -112,47 +122,73 @@ def fuse_model(self, is_qat: Optional[bool] = None) -> None: m.fuse_model(is_qat) -def _load_weights(arch: str, model: QuantizableMobileNetV3, model_url: Optional[str], progress: bool) -> None: - if model_url is None: - raise ValueError(f"No checkpoint is available for {arch}") - state_dict = load_state_dict_from_url(model_url, progress=progress) - model.load_state_dict(state_dict) - - def _mobilenet_v3_model( - arch: str, inverted_residual_setting: List[InvertedResidualConfig], last_channel: int, - pretrained: bool, + weights: Optional[WeightsEnum], progress: bool, quantize: bool, **kwargs: Any, ) -> QuantizableMobileNetV3: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + if "backend" in weights.meta: + _ovewrite_named_param(kwargs, "backend", weights.meta["backend"]) + backend = kwargs.pop("backend", "qnnpack") model = QuantizableMobileNetV3(inverted_residual_setting, last_channel, block=QuantizableInvertedResidual, **kwargs) _replace_relu(model) if quantize: - backend = "qnnpack" - model.fuse_model(is_qat=True) model.qconfig = torch.ao.quantization.get_default_qat_qconfig(backend) torch.ao.quantization.prepare_qat(model, inplace=True) - if pretrained: - _load_weights(arch, model, quant_model_urls.get(arch + "_" + backend, None), progress) + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + if quantize: torch.ao.quantization.convert(model, inplace=True) model.eval() - else: - if pretrained: - _load_weights(arch, model, model_urls.get(arch, None), progress) return model +class MobileNet_V3_Large_QuantizedWeights(WeightsEnum): + IMAGENET1K_QNNPACK_V1 = Weights( + url="https://download.pytorch.org/models/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + "task": "image_classification", + "architecture": "MobileNetV3", + "publication_year": 2019, + "num_params": 5483032, + "size": (224, 224), + "min_size": (1, 1), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "backend": "qnnpack", + "quantization": "Quantization Aware Training", + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#qat-mobilenetv3", + "unquantized": MobileNet_V3_Large_Weights.IMAGENET1K_V1, + "acc@1": 73.004, + "acc@5": 90.858, + }, + ) + DEFAULT 
= IMAGENET1K_QNNPACK_V1 + + +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: MobileNet_V3_Large_QuantizedWeights.IMAGENET1K_QNNPACK_V1 + if kwargs.get("quantize", False) + else MobileNet_V3_Large_Weights.IMAGENET1K_V1, + ) +) def mobilenet_v3_large( - pretrained: bool = False, + *, + weights: Optional[Union[MobileNet_V3_Large_QuantizedWeights, MobileNet_V3_Large_Weights]] = None, progress: bool = True, quantize: bool = False, **kwargs: Any, @@ -166,10 +202,12 @@ def mobilenet_v3_large( GPU inference is not yet supported Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet. - progress (bool): If True, displays a progress bar of the download to stderr - quantize (bool): If True, returns a quantized model, else returns a float model + weights (MobileNet_V3_Large_QuantizedWeights or MobileNet_V3_Large_Weights, optional): The pretrained + weights for the model + progress (bool): If True, displays a progress bar of the download to stderr + quantize (bool): If True, returns a quantized model, else returns a float model """ - arch = "mobilenet_v3_large" - inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, **kwargs) - return _mobilenet_v3_model(arch, inverted_residual_setting, last_channel, pretrained, progress, quantize, **kwargs) + weights = (MobileNet_V3_Large_QuantizedWeights if quantize else MobileNet_V3_Large_Weights).verify(weights) + + inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large", **kwargs) + return _mobilenet_v3_model(inverted_residual_setting, last_channel, weights, progress, quantize, **kwargs) diff --git a/torchvision/models/quantization/resnet.py b/torchvision/models/quantization/resnet.py index f55aa0e103c..a781f320000 100644 --- a/torchvision/models/quantization/resnet.py +++ b/torchvision/models/quantization/resnet.py @@ -1,21 +1,34 @@ +from functools import partial from typing import Any, Type, Union, List, Optional import torch import torch.nn as nn from torch import Tensor -from torchvision.models.resnet import Bottleneck, BasicBlock, ResNet, model_urls - -from ..._internally_replaced_utils import load_state_dict_from_url +from torchvision.models.resnet import ( + Bottleneck, + BasicBlock, + ResNet, + ResNet18_Weights, + ResNet50_Weights, + ResNeXt101_32X8D_Weights, +) + +from ...transforms._presets import ImageClassification, InterpolationMode +from .._api import WeightsEnum, Weights +from .._meta import _IMAGENET_CATEGORIES +from .._utils import handle_legacy_interface, _ovewrite_named_param from .utils import _fuse_modules, _replace_relu, quantize_model -__all__ = ["QuantizableResNet", "resnet18", "resnet50", "resnext101_32x8d"] - -quant_model_urls = { - "resnet18_fbgemm": "https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth", - "resnet50_fbgemm": "https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth", - "resnext101_32x8d_fbgemm": "https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth", -} +__all__ = [ + "QuantizableResNet", + "ResNet18_QuantizedWeights", + "ResNet50_QuantizedWeights", + "ResNeXt101_32X8D_QuantizedWeights", + "resnet18", + "resnet50", + "resnext101_32x8d", +] class QuantizableBasicBlock(BasicBlock): @@ -109,38 +122,130 @@ def fuse_model(self, is_qat: Optional[bool] = None) -> None: def _resnet( - arch: str, block: Type[Union[QuantizableBasicBlock, QuantizableBottleneck]], layers: List[int], - pretrained: bool, + weights: Optional[WeightsEnum], progress: bool, quantize: bool, **kwargs: Any, ) 
-> QuantizableResNet: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + if "backend" in weights.meta: + _ovewrite_named_param(kwargs, "backend", weights.meta["backend"]) + backend = kwargs.pop("backend", "fbgemm") model = QuantizableResNet(block, layers, **kwargs) _replace_relu(model) if quantize: - # TODO use pretrained as a string to specify the backend - backend = "fbgemm" quantize_model(model, backend) - else: - assert pretrained in [True, False] - if pretrained: - if quantize: - model_url = quant_model_urls[arch + "_" + backend] - else: - model_url = model_urls[arch] + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) - state_dict = load_state_dict_from_url(model_url, progress=progress) - - model.load_state_dict(state_dict) return model +_COMMON_META = { + "task": "image_classification", + "size": (224, 224), + "min_size": (1, 1), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "backend": "fbgemm", + "quantization": "Post Training Quantization", + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models", +} + + +class ResNet18_QuantizedWeights(WeightsEnum): + IMAGENET1K_FBGEMM_V1 = Weights( + url="https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "architecture": "ResNet", + "publication_year": 2015, + "num_params": 11689512, + "unquantized": ResNet18_Weights.IMAGENET1K_V1, + "acc@1": 69.494, + "acc@5": 88.882, + }, + ) + DEFAULT = IMAGENET1K_FBGEMM_V1 + + +class ResNet50_QuantizedWeights(WeightsEnum): + IMAGENET1K_FBGEMM_V1 = Weights( + url="https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "architecture": "ResNet", + "publication_year": 2015, + "num_params": 25557032, + "unquantized": ResNet50_Weights.IMAGENET1K_V1, + "acc@1": 75.920, + "acc@5": 92.814, + }, + ) + IMAGENET1K_FBGEMM_V2 = Weights( + url="https://download.pytorch.org/models/quantized/resnet50_fbgemm-23753f79.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "architecture": "ResNet", + "publication_year": 2015, + "num_params": 25557032, + "unquantized": ResNet50_Weights.IMAGENET1K_V2, + "acc@1": 80.282, + "acc@5": 94.976, + }, + ) + DEFAULT = IMAGENET1K_FBGEMM_V2 + + +class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum): + IMAGENET1K_FBGEMM_V1 = Weights( + url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "architecture": "ResNeXt", + "publication_year": 2016, + "num_params": 88791336, + "unquantized": ResNeXt101_32X8D_Weights.IMAGENET1K_V1, + "acc@1": 78.986, + "acc@5": 94.480, + }, + ) + IMAGENET1K_FBGEMM_V2 = Weights( + url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm-ee16d00c.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "architecture": "ResNeXt", + "publication_year": 2016, + "num_params": 88791336, + "unquantized": ResNeXt101_32X8D_Weights.IMAGENET1K_V2, + "acc@1": 82.574, + "acc@5": 96.132, + }, + ) + DEFAULT = IMAGENET1K_FBGEMM_V2 + + +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: 
ResNet18_QuantizedWeights.IMAGENET1K_FBGEMM_V1 + if kwargs.get("quantize", False) + else ResNet18_Weights.IMAGENET1K_V1, + ) +) def resnet18( - pretrained: bool = False, + *, + weights: Optional[Union[ResNet18_QuantizedWeights, ResNet18_Weights]] = None, progress: bool = True, quantize: bool = False, **kwargs: Any, @@ -149,33 +254,56 @@ def resnet18( `"Deep Residual Learning for Image Recognition" `_ Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ResNet18_QuantizedWeights or ResNet18_Weights, optional): The pretrained + weights for the model progress (bool): If True, displays a progress bar of the download to stderr quantize (bool): If True, return a quantized version of the model """ - return _resnet("resnet18", QuantizableBasicBlock, [2, 2, 2, 2], pretrained, progress, quantize, **kwargs) + weights = (ResNet18_QuantizedWeights if quantize else ResNet18_Weights).verify(weights) + + return _resnet(QuantizableBasicBlock, [2, 2, 2, 2], weights, progress, quantize, **kwargs) +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1 + if kwargs.get("quantize", False) + else ResNet50_Weights.IMAGENET1K_V1, + ) +) def resnet50( - pretrained: bool = False, + *, + weights: Optional[Union[ResNet50_QuantizedWeights, ResNet50_Weights]] = None, progress: bool = True, quantize: bool = False, **kwargs: Any, ) -> QuantizableResNet: - r"""ResNet-50 model from `"Deep Residual Learning for Image Recognition" `_ Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ResNet50_QuantizedWeights or ResNet50_Weights, optional): The pretrained + weights for the model progress (bool): If True, displays a progress bar of the download to stderr quantize (bool): If True, return a quantized version of the model """ - return _resnet("resnet50", QuantizableBottleneck, [3, 4, 6, 3], pretrained, progress, quantize, **kwargs) + weights = (ResNet50_QuantizedWeights if quantize else ResNet50_Weights).verify(weights) + return _resnet(QuantizableBottleneck, [3, 4, 6, 3], weights, progress, quantize, **kwargs) + +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: ResNeXt101_32X8D_QuantizedWeights.IMAGENET1K_FBGEMM_V1 + if kwargs.get("quantize", False) + else ResNeXt101_32X8D_Weights.IMAGENET1K_V1, + ) +) def resnext101_32x8d( - pretrained: bool = False, + *, + weights: Optional[Union[ResNeXt101_32X8D_QuantizedWeights, ResNeXt101_32X8D_Weights]] = None, progress: bool = True, quantize: bool = False, **kwargs: Any, @@ -184,10 +312,13 @@ def resnext101_32x8d( `"Aggregated Residual Transformation for Deep Neural Networks" `_ Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ResNeXt101_32X8D_QuantizedWeights or ResNeXt101_32X8D_Weights, optional): The pretrained + weights for the model progress (bool): If True, displays a progress bar of the download to stderr quantize (bool): If True, return a quantized version of the model """ - kwargs["groups"] = 32 - kwargs["width_per_group"] = 8 - return _resnet("resnext101_32x8d", QuantizableBottleneck, [3, 4, 23, 3], pretrained, progress, quantize, **kwargs) + weights = (ResNeXt101_32X8D_QuantizedWeights if quantize else ResNeXt101_32X8D_Weights).verify(weights) + + _ovewrite_named_param(kwargs, "groups", 32) + _ovewrite_named_param(kwargs, "width_per_group", 8) + return _resnet(QuantizableBottleneck, [3, 4, 23, 3], weights, progress, quantize, **kwargs) diff --git 
a/torchvision/models/quantization/shufflenetv2.py b/torchvision/models/quantization/shufflenetv2.py index 9d25315ffa0..1f4f1890e07 100644 --- a/torchvision/models/quantization/shufflenetv2.py +++ b/torchvision/models/quantization/shufflenetv2.py @@ -1,24 +1,27 @@ -from typing import Any, Optional +from functools import partial +from typing import Any, List, Optional, Union import torch import torch.nn as nn from torch import Tensor from torchvision.models import shufflenetv2 -from ..._internally_replaced_utils import load_state_dict_from_url +from ...transforms._presets import ImageClassification, InterpolationMode +from .._api import WeightsEnum, Weights +from .._meta import _IMAGENET_CATEGORIES +from .._utils import handle_legacy_interface, _ovewrite_named_param +from ..shufflenetv2 import ShuffleNet_V2_X0_5_Weights, ShuffleNet_V2_X1_0_Weights from .utils import _fuse_modules, _replace_relu, quantize_model + __all__ = [ "QuantizableShuffleNetV2", + "ShuffleNet_V2_X0_5_QuantizedWeights", + "ShuffleNet_V2_X1_0_QuantizedWeights", "shufflenet_v2_x0_5", "shufflenet_v2_x1_0", ] -quant_model_urls = { - "shufflenetv2_x0.5_fbgemm": "https://download.pytorch.org/models/quantized/shufflenetv2_x0.5_fbgemm-00845098.pth", - "shufflenetv2_x1.0_fbgemm": "https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-db332c57.pth", -} - class QuantizableInvertedResidual(shufflenetv2.InvertedResidual): def __init__(self, *args: Any, **kwargs: Any) -> None: @@ -73,39 +76,86 @@ def fuse_model(self, is_qat: Optional[bool] = None) -> None: def _shufflenetv2( - arch: str, - pretrained: bool, + stages_repeats: List[int], + stages_out_channels: List[int], + *, + weights: Optional[WeightsEnum], progress: bool, quantize: bool, - *args: Any, **kwargs: Any, ) -> QuantizableShuffleNetV2: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + if "backend" in weights.meta: + _ovewrite_named_param(kwargs, "backend", weights.meta["backend"]) + backend = kwargs.pop("backend", "fbgemm") - model = QuantizableShuffleNetV2(*args, **kwargs) + model = QuantizableShuffleNetV2(stages_repeats, stages_out_channels, **kwargs) _replace_relu(model) - if quantize: - # TODO use pretrained as a string to specify the backend - backend = "fbgemm" quantize_model(model, backend) - else: - assert pretrained in [True, False] - if pretrained: - model_url: Optional[str] = None - if quantize: - model_url = quant_model_urls[arch + "_" + backend] - else: - model_url = shufflenetv2.model_urls[arch] - - state_dict = load_state_dict_from_url(model_url, progress=progress) + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) - model.load_state_dict(state_dict) return model +_COMMON_META = { + "task": "image_classification", + "architecture": "ShuffleNetV2", + "publication_year": 2018, + "size": (224, 224), + "min_size": (1, 1), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "backend": "fbgemm", + "quantization": "Post Training Quantization", + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models", +} + + +class ShuffleNet_V2_X0_5_QuantizedWeights(WeightsEnum): + IMAGENET1K_FBGEMM_V1 = Weights( + url="https://download.pytorch.org/models/quantized/shufflenetv2_x0.5_fbgemm-00845098.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 1366792, + "unquantized": 
ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1, + "acc@1": 57.972, + "acc@5": 79.780, + }, + ) + DEFAULT = IMAGENET1K_FBGEMM_V1 + + +class ShuffleNet_V2_X1_0_QuantizedWeights(WeightsEnum): + IMAGENET1K_FBGEMM_V1 = Weights( + url="https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-db332c57.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 2278604, + "unquantized": ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1, + "acc@1": 68.360, + "acc@5": 87.582, + }, + ) + DEFAULT = IMAGENET1K_FBGEMM_V1 + + +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: ShuffleNet_V2_X0_5_QuantizedWeights.IMAGENET1K_FBGEMM_V1 + if kwargs.get("quantize", False) + else ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1, + ) +) def shufflenet_v2_x0_5( - pretrained: bool = False, + *, + weights: Optional[Union[ShuffleNet_V2_X0_5_QuantizedWeights, ShuffleNet_V2_X0_5_Weights]] = None, progress: bool = True, quantize: bool = False, **kwargs: Any, @@ -116,17 +166,28 @@ def shufflenet_v2_x0_5( `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ShuffleNet_V2_X0_5_QuantizedWeights or ShuffleNet_V2_X0_5_Weights, optional): The pretrained + weights for the model progress (bool): If True, displays a progress bar of the download to stderr quantize (bool): If True, return a quantized version of the model """ + weights = (ShuffleNet_V2_X0_5_QuantizedWeights if quantize else ShuffleNet_V2_X0_5_Weights).verify(weights) return _shufflenetv2( - "shufflenetv2_x0.5", pretrained, progress, quantize, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs + [4, 8, 4], [24, 48, 96, 192, 1024], weights=weights, progress=progress, quantize=quantize, **kwargs ) +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: ShuffleNet_V2_X1_0_QuantizedWeights.IMAGENET1K_FBGEMM_V1 + if kwargs.get("quantize", False) + else ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1, + ) +) def shufflenet_v2_x1_0( - pretrained: bool = False, + *, + weights: Optional[Union[ShuffleNet_V2_X1_0_QuantizedWeights, ShuffleNet_V2_X1_0_Weights]] = None, progress: bool = True, quantize: bool = False, **kwargs: Any, @@ -137,10 +198,12 @@ def shufflenet_v2_x1_0( `_. 
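Finally, a sketch for the refactored quantized ShuffleNetV2 entry points above (illustrative; the stage arguments are now fixed inside the builders, so callers only pass weights, progress, and quantize):

    from torchvision.models.quantization import (
        shufflenet_v2_x0_5,
        ShuffleNet_V2_X0_5_QuantizedWeights,
    )

    model = shufflenet_v2_x0_5(
        weights=ShuffleNet_V2_X0_5_QuantizedWeights.IMAGENET1K_FBGEMM_V1,
        quantize=True,
    )

    # pretrained=True still works: with quantize=False the decorator's lambda
    # falls back to the float ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1 entry.
    float_model = shufflenet_v2_x0_5(pretrained=True)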
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ShuffleNet_V2_X1_0_QuantizedWeights or ShuffleNet_V2_X1_0_Weights, optional): The pretrained + weights for the model progress (bool): If True, displays a progress bar of the download to stderr quantize (bool): If True, return a quantized version of the model """ + weights = (ShuffleNet_V2_X1_0_QuantizedWeights if quantize else ShuffleNet_V2_X1_0_Weights).verify(weights) return _shufflenetv2( - "shufflenetv2_x1.0", pretrained, progress, quantize, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs + [4, 8, 4], [24, 116, 232, 464, 1024], weights=weights, progress=progress, quantize=quantize, **kwargs ) diff --git a/torchvision/models/regnet.py b/torchvision/models/regnet.py index 74abd20b237..72093686d84 100644 --- a/torchvision/models/regnet.py +++ b/torchvision/models/regnet.py @@ -1,8 +1,3 @@ -# Modified from -# https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/models/anynet.py -# https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/models/regnet.py - - import math from collections import OrderedDict from functools import partial @@ -11,14 +6,31 @@ import torch from torch import nn, Tensor -from .._internally_replaced_utils import load_state_dict_from_url from ..ops.misc import Conv2dNormActivation, SqueezeExcitation +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once -from ._utils import _make_divisible +from ._api import WeightsEnum, Weights +from ._meta import _IMAGENET_CATEGORIES +from ._utils import handle_legacy_interface, _ovewrite_named_param, _make_divisible __all__ = [ "RegNet", + "RegNet_Y_400MF_Weights", + "RegNet_Y_800MF_Weights", + "RegNet_Y_1_6GF_Weights", + "RegNet_Y_3_2GF_Weights", + "RegNet_Y_8GF_Weights", + "RegNet_Y_16GF_Weights", + "RegNet_Y_32GF_Weights", + "RegNet_Y_128GF_Weights", + "RegNet_X_400MF_Weights", + "RegNet_X_800MF_Weights", + "RegNet_X_1_6GF_Weights", + "RegNet_X_3_2GF_Weights", + "RegNet_X_8GF_Weights", + "RegNet_X_16GF_Weights", + "RegNet_X_32GF_Weights", "regnet_y_400mf", "regnet_y_800mf", "regnet_y_1_6gf", @@ -37,24 +49,6 @@ ] -model_urls = { - "regnet_y_400mf": "https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth", - "regnet_y_800mf": "https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth", - "regnet_y_1_6gf": "https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth", - "regnet_y_3_2gf": "https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth", - "regnet_y_8gf": "https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth", - "regnet_y_16gf": "https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth", - "regnet_y_32gf": "https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth", - "regnet_x_400mf": "https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth", - "regnet_x_800mf": "https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth", - "regnet_x_1_6gf": "https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth", - "regnet_x_3_2gf": "https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth", - "regnet_x_8gf": "https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth", - "regnet_x_16gf": "https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth", - "regnet_x_32gf": "https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth", -} - - class SimpleStemIN(Conv2dNormActivation): """Simple stem for ImageNet: 3x3, BN, ReLU.""" @@ -390,219 
+384,652 @@ def forward(self, x: Tensor) -> Tensor: return x -def _regnet(arch: str, block_params: BlockParams, pretrained: bool, progress: bool, **kwargs: Any) -> RegNet: +def _regnet( + block_params: BlockParams, + weights: Optional[WeightsEnum], + progress: bool, + **kwargs: Any, +) -> RegNet: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + norm_layer = kwargs.pop("norm_layer", partial(nn.BatchNorm2d, eps=1e-05, momentum=0.1)) model = RegNet(block_params, norm_layer=norm_layer, **kwargs) - if pretrained: - if arch not in model_urls: - raise ValueError(f"No checkpoint is available for model type {arch}") - state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) - model.load_state_dict(state_dict) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model -def regnet_y_400mf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: +_COMMON_META = { + "task": "image_classification", + "architecture": "RegNet", + "publication_year": 2020, + "size": (224, 224), + "min_size": (1, 1), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, +} + + +class RegNet_Y_400MF_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 4344144, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models", + "acc@1": 74.046, + "acc@5": 91.716, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/regnet_y_400mf-e6988f5f.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 4344144, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "acc@1": 75.804, + "acc@5": 92.742, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class RegNet_Y_800MF_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 6432512, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models", + "acc@1": 76.420, + "acc@5": 93.136, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/regnet_y_800mf-58fc7688.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 6432512, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "acc@1": 78.828, + "acc@5": 94.502, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class RegNet_Y_1_6GF_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 11202430, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models", + "acc@1": 77.950, + "acc@5": 93.966, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/regnet_y_1_6gf-0d7bc02a.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 11202430, + "recipe": 
"https://github.com/pytorch/vision/issues/3995#new-recipe", + "acc@1": 80.876, + "acc@5": 95.444, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class RegNet_Y_3_2GF_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 19436338, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models", + "acc@1": 78.948, + "acc@5": 94.576, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/regnet_y_3_2gf-9180c971.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 19436338, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "acc@1": 81.982, + "acc@5": 95.972, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class RegNet_Y_8GF_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 39381472, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models", + "acc@1": 80.032, + "acc@5": 95.048, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/regnet_y_8gf-dc2b1b54.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 39381472, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "acc@1": 82.828, + "acc@5": 96.330, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class RegNet_Y_16GF_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 83590140, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#large-models", + "acc@1": 80.424, + "acc@5": 95.240, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/regnet_y_16gf-3e4a00f9.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 83590140, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "acc@1": 82.886, + "acc@5": 96.328, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class RegNet_Y_32GF_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 145046770, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#large-models", + "acc@1": 80.878, + "acc@5": 95.340, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/regnet_y_32gf-8db6d4b5.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 145046770, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "acc@1": 83.368, + "acc@5": 96.498, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class RegNet_Y_128GF_Weights(WeightsEnum): + # weights are not available yet. 
+ pass + + +class RegNet_X_400MF_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 5495976, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models", + "acc@1": 72.834, + "acc@5": 90.950, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/regnet_x_400mf-62229a5f.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 5495976, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres", + "acc@1": 74.864, + "acc@5": 92.322, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class RegNet_X_800MF_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 7259656, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models", + "acc@1": 75.212, + "acc@5": 92.348, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/regnet_x_800mf-94a99ebd.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 7259656, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres", + "acc@1": 77.522, + "acc@5": 93.826, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class RegNet_X_1_6GF_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 9190136, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models", + "acc@1": 77.040, + "acc@5": 93.440, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/regnet_x_1_6gf-a12f2b72.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 9190136, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres", + "acc@1": 79.668, + "acc@5": 94.922, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class RegNet_X_3_2GF_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 15296552, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models", + "acc@1": 78.364, + "acc@5": 93.992, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/regnet_x_3_2gf-7071aa85.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 15296552, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "acc@1": 81.196, + "acc@5": 95.430, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class RegNet_X_8GF_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 39572648, + 
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models", + "acc@1": 79.344, + "acc@5": 94.686, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/regnet_x_8gf-2b70d774.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 39572648, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "acc@1": 81.682, + "acc@5": 95.678, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class RegNet_X_16GF_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 54278536, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models", + "acc@1": 80.058, + "acc@5": 94.944, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/regnet_x_16gf-ba3796d7.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 54278536, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "acc@1": 82.716, + "acc@5": 96.196, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class RegNet_X_32GF_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 107811560, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#large-models", + "acc@1": 80.622, + "acc@5": 95.248, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/regnet_x_32gf-6eb8fdc6.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 107811560, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "acc@1": 83.014, + "acc@5": 96.288, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +@handle_legacy_interface(weights=("pretrained", RegNet_Y_400MF_Weights.IMAGENET1K_V1)) +def regnet_y_400mf(*, weights: Optional[RegNet_Y_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetY_400MF architecture from `"Designing Network Design Spaces" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (RegNet_Y_400MF_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = RegNet_Y_400MF_Weights.verify(weights) + params = BlockParams.from_init_params(depth=16, w_0=48, w_a=27.89, w_m=2.09, group_width=8, se_ratio=0.25, **kwargs) - return _regnet("regnet_y_400mf", params, pretrained, progress, **kwargs) + return _regnet(params, weights, progress, **kwargs) -def regnet_y_800mf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: +@handle_legacy_interface(weights=("pretrained", RegNet_Y_800MF_Weights.IMAGENET1K_V1)) +def regnet_y_800mf(*, weights: Optional[RegNet_Y_800MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetY_800MF architecture from `"Designing Network Design Spaces" `_. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (RegNet_Y_800MF_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = RegNet_Y_800MF_Weights.verify(weights) + params = BlockParams.from_init_params(depth=14, w_0=56, w_a=38.84, w_m=2.4, group_width=16, se_ratio=0.25, **kwargs) - return _regnet("regnet_y_800mf", params, pretrained, progress, **kwargs) + return _regnet(params, weights, progress, **kwargs) -def regnet_y_1_6gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: +@handle_legacy_interface(weights=("pretrained", RegNet_Y_1_6GF_Weights.IMAGENET1K_V1)) +def regnet_y_1_6gf(*, weights: Optional[RegNet_Y_1_6GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetY_1.6GF architecture from `"Designing Network Design Spaces" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (RegNet_Y_1_6GF_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = RegNet_Y_1_6GF_Weights.verify(weights) + params = BlockParams.from_init_params( depth=27, w_0=48, w_a=20.71, w_m=2.65, group_width=24, se_ratio=0.25, **kwargs ) - return _regnet("regnet_y_1_6gf", params, pretrained, progress, **kwargs) + return _regnet(params, weights, progress, **kwargs) -def regnet_y_3_2gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: +@handle_legacy_interface(weights=("pretrained", RegNet_Y_3_2GF_Weights.IMAGENET1K_V1)) +def regnet_y_3_2gf(*, weights: Optional[RegNet_Y_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetY_3.2GF architecture from `"Designing Network Design Spaces" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (RegNet_Y_3_2GF_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = RegNet_Y_3_2GF_Weights.verify(weights) + params = BlockParams.from_init_params( depth=21, w_0=80, w_a=42.63, w_m=2.66, group_width=24, se_ratio=0.25, **kwargs ) - return _regnet("regnet_y_3_2gf", params, pretrained, progress, **kwargs) + return _regnet(params, weights, progress, **kwargs) -def regnet_y_8gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: +@handle_legacy_interface(weights=("pretrained", RegNet_Y_8GF_Weights.IMAGENET1K_V1)) +def regnet_y_8gf(*, weights: Optional[RegNet_Y_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetY_8GF architecture from `"Designing Network Design Spaces" `_. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (RegNet_Y_8GF_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = RegNet_Y_8GF_Weights.verify(weights) + params = BlockParams.from_init_params( depth=17, w_0=192, w_a=76.82, w_m=2.19, group_width=56, se_ratio=0.25, **kwargs ) - return _regnet("regnet_y_8gf", params, pretrained, progress, **kwargs) + return _regnet(params, weights, progress, **kwargs) -def regnet_y_16gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: +@handle_legacy_interface(weights=("pretrained", RegNet_Y_16GF_Weights.IMAGENET1K_V1)) +def regnet_y_16gf(*, weights: Optional[RegNet_Y_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetY_16GF architecture from `"Designing Network Design Spaces" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (RegNet_Y_16GF_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = RegNet_Y_16GF_Weights.verify(weights) + params = BlockParams.from_init_params( depth=18, w_0=200, w_a=106.23, w_m=2.48, group_width=112, se_ratio=0.25, **kwargs ) - return _regnet("regnet_y_16gf", params, pretrained, progress, **kwargs) + return _regnet(params, weights, progress, **kwargs) -def regnet_y_32gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: +@handle_legacy_interface(weights=("pretrained", RegNet_Y_32GF_Weights.IMAGENET1K_V1)) +def regnet_y_32gf(*, weights: Optional[RegNet_Y_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetY_32GF architecture from `"Designing Network Design Spaces" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (RegNet_Y_32GF_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = RegNet_Y_32GF_Weights.verify(weights) + params = BlockParams.from_init_params( depth=20, w_0=232, w_a=115.89, w_m=2.53, group_width=232, se_ratio=0.25, **kwargs ) - return _regnet("regnet_y_32gf", params, pretrained, progress, **kwargs) + return _regnet(params, weights, progress, **kwargs) -def regnet_y_128gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: +@handle_legacy_interface(weights=("pretrained", None)) +def regnet_y_128gf(*, weights: Optional[RegNet_Y_128GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetY_128GF architecture from `"Designing Network Design Spaces" `_. NOTE: Pretrained weights are not available for this model. 
+ + Args: + weights (RegNet_Y_128GF_Weights, optional): The pretrained weights for the model + progress (bool): If True, displays a progress bar of the download to stderr """ + weights = RegNet_Y_128GF_Weights.verify(weights) + params = BlockParams.from_init_params( depth=27, w_0=456, w_a=160.83, w_m=2.52, group_width=264, se_ratio=0.25, **kwargs ) - return _regnet("regnet_y_128gf", params, pretrained, progress, **kwargs) + return _regnet(params, weights, progress, **kwargs) -def regnet_x_400mf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: +@handle_legacy_interface(weights=("pretrained", RegNet_X_400MF_Weights.IMAGENET1K_V1)) +def regnet_x_400mf(*, weights: Optional[RegNet_X_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetX_400MF architecture from `"Designing Network Design Spaces" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (RegNet_X_400MF_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = RegNet_X_400MF_Weights.verify(weights) + params = BlockParams.from_init_params(depth=22, w_0=24, w_a=24.48, w_m=2.54, group_width=16, **kwargs) - return _regnet("regnet_x_400mf", params, pretrained, progress, **kwargs) + return _regnet(params, weights, progress, **kwargs) -def regnet_x_800mf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: +@handle_legacy_interface(weights=("pretrained", RegNet_X_800MF_Weights.IMAGENET1K_V1)) +def regnet_x_800mf(*, weights: Optional[RegNet_X_800MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetX_800MF architecture from `"Designing Network Design Spaces" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (RegNet_X_800MF_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = RegNet_X_800MF_Weights.verify(weights) + params = BlockParams.from_init_params(depth=16, w_0=56, w_a=35.73, w_m=2.28, group_width=16, **kwargs) - return _regnet("regnet_x_800mf", params, pretrained, progress, **kwargs) + return _regnet(params, weights, progress, **kwargs) -def regnet_x_1_6gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: +@handle_legacy_interface(weights=("pretrained", RegNet_X_1_6GF_Weights.IMAGENET1K_V1)) +def regnet_x_1_6gf(*, weights: Optional[RegNet_X_1_6GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetX_1.6GF architecture from `"Designing Network Design Spaces" `_. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (RegNet_X_1_6GF_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = RegNet_X_1_6GF_Weights.verify(weights) + params = BlockParams.from_init_params(depth=18, w_0=80, w_a=34.01, w_m=2.25, group_width=24, **kwargs) - return _regnet("regnet_x_1_6gf", params, pretrained, progress, **kwargs) + return _regnet(params, weights, progress, **kwargs) -def regnet_x_3_2gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: +@handle_legacy_interface(weights=("pretrained", RegNet_X_3_2GF_Weights.IMAGENET1K_V1)) +def regnet_x_3_2gf(*, weights: Optional[RegNet_X_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetX_3.2GF architecture from `"Designing Network Design Spaces" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (RegNet_X_3_2GF_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = RegNet_X_3_2GF_Weights.verify(weights) + params = BlockParams.from_init_params(depth=25, w_0=88, w_a=26.31, w_m=2.25, group_width=48, **kwargs) - return _regnet("regnet_x_3_2gf", params, pretrained, progress, **kwargs) + return _regnet(params, weights, progress, **kwargs) -def regnet_x_8gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: +@handle_legacy_interface(weights=("pretrained", RegNet_X_8GF_Weights.IMAGENET1K_V1)) +def regnet_x_8gf(*, weights: Optional[RegNet_X_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetX_8GF architecture from `"Designing Network Design Spaces" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (RegNet_X_8GF_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = RegNet_X_8GF_Weights.verify(weights) + params = BlockParams.from_init_params(depth=23, w_0=80, w_a=49.56, w_m=2.88, group_width=120, **kwargs) - return _regnet("regnet_x_8gf", params, pretrained, progress, **kwargs) + return _regnet(params, weights, progress, **kwargs) -def regnet_x_16gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: +@handle_legacy_interface(weights=("pretrained", RegNet_X_16GF_Weights.IMAGENET1K_V1)) +def regnet_x_16gf(*, weights: Optional[RegNet_X_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetX_16GF architecture from `"Designing Network Design Spaces" `_. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (RegNet_X_16GF_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = RegNet_X_16GF_Weights.verify(weights) + params = BlockParams.from_init_params(depth=22, w_0=216, w_a=55.59, w_m=2.1, group_width=128, **kwargs) - return _regnet("regnet_x_16gf", params, pretrained, progress, **kwargs) + return _regnet(params, weights, progress, **kwargs) -def regnet_x_32gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: +@handle_legacy_interface(weights=("pretrained", RegNet_X_32GF_Weights.IMAGENET1K_V1)) +def regnet_x_32gf(*, weights: Optional[RegNet_X_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetX_32GF architecture from `"Designing Network Design Spaces" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (RegNet_X_32GF_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - params = BlockParams.from_init_params(depth=23, w_0=320, w_a=69.86, w_m=2.0, group_width=168, **kwargs) - return _regnet("regnet_x_32gf", params, pretrained, progress, **kwargs) + weights = RegNet_X_32GF_Weights.verify(weights) - -# TODO(kazhang): Add RegNetZ_500MF and RegNetZ_4GF + params = BlockParams.from_init_params(depth=23, w_0=320, w_a=69.86, w_m=2.0, group_width=168, **kwargs) + return _regnet(params, weights, progress, **kwargs) diff --git a/torchvision/models/resnet.py b/torchvision/models/resnet.py index b0bb8d13ade..8f44e553296 100644 --- a/torchvision/models/resnet.py +++ b/torchvision/models/resnet.py @@ -1,15 +1,28 @@ +from functools import partial from typing import Type, Any, Callable, Union, List, Optional import torch import torch.nn as nn from torch import Tensor -from .._internally_replaced_utils import load_state_dict_from_url +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once +from ._api import WeightsEnum, Weights +from ._meta import _IMAGENET_CATEGORIES +from ._utils import handle_legacy_interface, _ovewrite_named_param __all__ = [ "ResNet", + "ResNet18_Weights", + "ResNet34_Weights", + "ResNet50_Weights", + "ResNet101_Weights", + "ResNet152_Weights", + "ResNeXt50_32X4D_Weights", + "ResNeXt101_32X8D_Weights", + "Wide_ResNet50_2_Weights", + "Wide_ResNet101_2_Weights", "resnet18", "resnet34", "resnet50", @@ -22,19 +35,6 @@ ] -model_urls = { - "resnet18": "https://download.pytorch.org/models/resnet18-f37072fd.pth", - "resnet34": "https://download.pytorch.org/models/resnet34-b627a593.pth", - "resnet50": "https://download.pytorch.org/models/resnet50-0676ba61.pth", - "resnet101": "https://download.pytorch.org/models/resnet101-63fe2227.pth", - "resnet152": "https://download.pytorch.org/models/resnet152-394f9c45.pth", - "resnext50_32x4d": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth", - "resnext101_32x8d": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth", - "wide_resnet50_2": "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth", - "wide_resnet101_2": "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth", -} - - def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: """3x3 convolution with padding""" return nn.Conv2d( @@ -284,102 +284,386 @@ def forward(self, x: 
Tensor) -> Tensor: def _resnet( - arch: str, block: Type[Union[BasicBlock, Bottleneck]], layers: List[int], - pretrained: bool, + weights: Optional[WeightsEnum], progress: bool, **kwargs: Any, ) -> ResNet: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + model = ResNet(block, layers, **kwargs) - if pretrained: - state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) - model.load_state_dict(state_dict) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model -def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: +_COMMON_META = { + "task": "image_classification", + "size": (224, 224), + "min_size": (1, 1), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, +} + + +class ResNet18_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/resnet18-f37072fd.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "architecture": "ResNet", + "publication_year": 2015, + "num_params": 11689512, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet", + "acc@1": 69.758, + "acc@5": 89.078, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class ResNet34_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/resnet34-b627a593.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "architecture": "ResNet", + "publication_year": 2015, + "num_params": 21797672, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet", + "acc@1": 73.314, + "acc@5": 91.420, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class ResNet50_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/resnet50-0676ba61.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "architecture": "ResNet", + "publication_year": 2015, + "num_params": 25557032, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet", + "acc@1": 76.130, + "acc@5": 92.862, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/resnet50-11ad3fa6.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "architecture": "ResNet", + "publication_year": 2015, + "num_params": 25557032, + "recipe": "https://github.com/pytorch/vision/issues/3995#issuecomment-1013906621", + "acc@1": 80.858, + "acc@5": 95.434, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class ResNet101_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/resnet101-63fe2227.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "architecture": "ResNet", + "publication_year": 2015, + "num_params": 44549160, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet", + "acc@1": 77.374, + "acc@5": 93.546, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/resnet101-cd907fc2.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "architecture": "ResNet", + "publication_year": 2015, + "num_params": 44549160, + "recipe": 
"https://github.com/pytorch/vision/issues/3995#new-recipe", + "acc@1": 81.886, + "acc@5": 95.780, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class ResNet152_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/resnet152-394f9c45.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "architecture": "ResNet", + "publication_year": 2015, + "num_params": 60192808, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet", + "acc@1": 78.312, + "acc@5": 94.046, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/resnet152-f82ba261.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "architecture": "ResNet", + "publication_year": 2015, + "num_params": 60192808, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "acc@1": 82.284, + "acc@5": 96.002, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class ResNeXt50_32X4D_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "architecture": "ResNeXt", + "publication_year": 2016, + "num_params": 25028904, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnext", + "acc@1": 77.618, + "acc@5": 93.698, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/resnext50_32x4d-1a0047aa.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "architecture": "ResNeXt", + "publication_year": 2016, + "num_params": 25028904, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "acc@1": 81.198, + "acc@5": 95.340, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class ResNeXt101_32X8D_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "architecture": "ResNeXt", + "publication_year": 2016, + "num_params": 88791336, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnext", + "acc@1": 79.312, + "acc@5": 94.526, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/resnext101_32x8d-110c445d.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "architecture": "ResNeXt", + "publication_year": 2016, + "num_params": 88791336, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres", + "acc@1": 82.834, + "acc@5": 96.228, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class Wide_ResNet50_2_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "architecture": "WideResNet", + "publication_year": 2016, + "num_params": 68883240, + "recipe": "https://github.com/pytorch/vision/pull/912#issue-445437439", + "acc@1": 78.468, + "acc@5": 94.086, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/wide_resnet50_2-9ba9bcbe.pth", + transforms=partial(ImageClassification, crop_size=224, 
resize_size=232), + meta={ + **_COMMON_META, + "architecture": "WideResNet", + "publication_year": 2016, + "num_params": 68883240, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres", + "acc@1": 81.602, + "acc@5": 95.758, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class Wide_ResNet101_2_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "architecture": "WideResNet", + "publication_year": 2016, + "num_params": 126886696, + "recipe": "https://github.com/pytorch/vision/pull/912#issue-445437439", + "acc@1": 78.848, + "acc@5": 94.284, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/wide_resnet101_2-d733dc28.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "architecture": "WideResNet", + "publication_year": 2016, + "num_params": 126886696, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "acc@1": 82.510, + "acc@5": 96.020, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +@handle_legacy_interface(weights=("pretrained", ResNet18_Weights.IMAGENET1K_V1)) +def resnet18(*, weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: r"""ResNet-18 model from `"Deep Residual Learning for Image Recognition" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ResNet18_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _resnet("resnet18", BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs) + weights = ResNet18_Weights.verify(weights) + + return _resnet(BasicBlock, [2, 2, 2, 2], weights, progress, **kwargs) -def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: +@handle_legacy_interface(weights=("pretrained", ResNet34_Weights.IMAGENET1K_V1)) +def resnet34(*, weights: Optional[ResNet34_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: r"""ResNet-34 model from `"Deep Residual Learning for Image Recognition" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ResNet34_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _resnet("resnet34", BasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs) + weights = ResNet34_Weights.verify(weights) + return _resnet(BasicBlock, [3, 4, 6, 3], weights, progress, **kwargs) -def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + +@handle_legacy_interface(weights=("pretrained", ResNet50_Weights.IMAGENET1K_V1)) +def resnet50(*, weights: Optional[ResNet50_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: r"""ResNet-50 model from `"Deep Residual Learning for Image Recognition" `_. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ResNet50_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _resnet("resnet50", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) + weights = ResNet50_Weights.verify(weights) + + return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs) -def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: +@handle_legacy_interface(weights=("pretrained", ResNet101_Weights.IMAGENET1K_V1)) +def resnet101(*, weights: Optional[ResNet101_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: r"""ResNet-101 model from `"Deep Residual Learning for Image Recognition" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ResNet101_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _resnet("resnet101", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs) + weights = ResNet101_Weights.verify(weights) + return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs) -def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + +@handle_legacy_interface(weights=("pretrained", ResNet152_Weights.IMAGENET1K_V1)) +def resnet152(*, weights: Optional[ResNet152_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: r"""ResNet-152 model from `"Deep Residual Learning for Image Recognition" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ResNet152_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _resnet("resnet152", Bottleneck, [3, 8, 36, 3], pretrained, progress, **kwargs) + weights = ResNet152_Weights.verify(weights) + + return _resnet(Bottleneck, [3, 8, 36, 3], weights, progress, **kwargs) -def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: +@handle_legacy_interface(weights=("pretrained", ResNeXt50_32X4D_Weights.IMAGENET1K_V1)) +def resnext50_32x4d( + *, weights: Optional[ResNeXt50_32X4D_Weights] = None, progress: bool = True, **kwargs: Any +) -> ResNet: r"""ResNeXt-50 32x4d model from `"Aggregated Residual Transformation for Deep Neural Networks" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ResNeXt50_32X4D_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - kwargs["groups"] = 32 - kwargs["width_per_group"] = 4 - return _resnet("resnext50_32x4d", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) + weights = ResNeXt50_32X4D_Weights.verify(weights) + _ovewrite_named_param(kwargs, "groups", 32) + _ovewrite_named_param(kwargs, "width_per_group", 4) + return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs) -def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + +@handle_legacy_interface(weights=("pretrained", ResNeXt101_32X8D_Weights.IMAGENET1K_V1)) +def resnext101_32x8d( + *, weights: Optional[ResNeXt101_32X8D_Weights] = None, progress: bool = True, **kwargs: Any +) -> ResNet: r"""ResNeXt-101 32x8d model from `"Aggregated Residual Transformation for Deep Neural Networks" `_. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ResNeXt101_32X8D_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - kwargs["groups"] = 32 - kwargs["width_per_group"] = 8 - return _resnet("resnext101_32x8d", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs) + weights = ResNeXt101_32X8D_Weights.verify(weights) + _ovewrite_named_param(kwargs, "groups", 32) + _ovewrite_named_param(kwargs, "width_per_group", 8) + return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs) -def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + +@handle_legacy_interface(weights=("pretrained", Wide_ResNet50_2_Weights.IMAGENET1K_V1)) +def wide_resnet50_2( + *, weights: Optional[Wide_ResNet50_2_Weights] = None, progress: bool = True, **kwargs: Any +) -> ResNet: r"""Wide ResNet-50-2 model from `"Wide Residual Networks" `_. @@ -389,14 +673,19 @@ def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: A channels, and in Wide ResNet-50-2 has 2048-1024-2048. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (Wide_ResNet50_2_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - kwargs["width_per_group"] = 64 * 2 - return _resnet("wide_resnet50_2", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) + weights = Wide_ResNet50_2_Weights.verify(weights) + + _ovewrite_named_param(kwargs, "width_per_group", 64 * 2) + return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs) -def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: +@handle_legacy_interface(weights=("pretrained", Wide_ResNet101_2_Weights.IMAGENET1K_V1)) +def wide_resnet101_2( + *, weights: Optional[Wide_ResNet101_2_Weights] = None, progress: bool = True, **kwargs: Any +) -> ResNet: r"""Wide ResNet-101-2 model from `"Wide Residual Networks" `_. @@ -406,8 +695,10 @@ def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: channels, and in Wide ResNet-50-2 has 2048-1024-2048. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (Wide_ResNet101_2_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - kwargs["width_per_group"] = 64 * 2 - return _resnet("wide_resnet101_2", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs) + weights = Wide_ResNet101_2_Weights.verify(weights) + + _ovewrite_named_param(kwargs, "width_per_group", 64 * 2) + return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs) diff --git a/torchvision/models/segmentation/__init__.py b/torchvision/models/segmentation/__init__.py index 1765502d693..3d6f37f958a 100644 --- a/torchvision/models/segmentation/__init__.py +++ b/torchvision/models/segmentation/__init__.py @@ -1,3 +1,3 @@ -from .fcn import * from .deeplabv3 import * +from .fcn import * from .lraspp import * diff --git a/torchvision/models/segmentation/_utils.py b/torchvision/models/segmentation/_utils.py index 0bbea5d3e81..44a60a95c54 100644 --- a/torchvision/models/segmentation/_utils.py +++ b/torchvision/models/segmentation/_utils.py @@ -4,7 +4,6 @@ from torch import nn, Tensor from torch.nn import functional as F -from ..._internally_replaced_utils import load_state_dict_from_url from ...utils import _log_api_usage_once @@ -36,10 +35,3 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]: result["aux"] = x return result - - -def _load_weights(arch: str, model: nn.Module, model_url: Optional[str], progress: bool) -> None: - if model_url is None: - raise ValueError(f"No checkpoint is available for {arch}") - state_dict = load_state_dict_from_url(model_url, progress=progress) - model.load_state_dict(state_dict) diff --git a/torchvision/models/segmentation/deeplabv3.py b/torchvision/models/segmentation/deeplabv3.py index 15ab5fffa5e..092a81f643b 100644 --- a/torchvision/models/segmentation/deeplabv3.py +++ b/torchvision/models/segmentation/deeplabv3.py @@ -1,31 +1,31 @@ -from typing import List, Optional +from functools import partial +from typing import Any, List, Optional import torch from torch import nn from torch.nn import functional as F -from .. import mobilenetv3 -from .. 
import resnet -from .._utils import IntermediateLayerGetter -from ._utils import _SimpleSegmentationModel, _load_weights +from ...transforms._presets import SemanticSegmentation, InterpolationMode +from .._api import WeightsEnum, Weights +from .._meta import _VOC_CATEGORIES +from .._utils import IntermediateLayerGetter, handle_legacy_interface, _ovewrite_value_param +from ..mobilenetv3 import MobileNetV3, MobileNet_V3_Large_Weights, mobilenet_v3_large +from ..resnet import ResNet, resnet50, resnet101, ResNet50_Weights, ResNet101_Weights +from ._utils import _SimpleSegmentationModel from .fcn import FCNHead __all__ = [ "DeepLabV3", + "DeepLabV3_ResNet50_Weights", + "DeepLabV3_ResNet101_Weights", + "DeepLabV3_MobileNet_V3_Large_Weights", + "deeplabv3_mobilenet_v3_large", "deeplabv3_resnet50", "deeplabv3_resnet101", - "deeplabv3_mobilenet_v3_large", ] -model_urls = { - "deeplabv3_resnet50_coco": "https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth", - "deeplabv3_resnet101_coco": "https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth", - "deeplabv3_mobilenet_v3_large_coco": "https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth", -} - - class DeepLabV3(_SimpleSegmentationModel): """ Implements DeepLabV3 model from @@ -114,7 +114,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def _deeplabv3_resnet( - backbone: resnet.ResNet, + backbone: ResNet, num_classes: int, aux: Optional[bool], ) -> DeepLabV3: @@ -128,8 +128,62 @@ def _deeplabv3_resnet( return DeepLabV3(backbone, classifier, aux_classifier) +_COMMON_META = { + "task": "image_semantic_segmentation", + "architecture": "DeepLabV3", + "publication_year": 2017, + "categories": _VOC_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, +} + + +class DeepLabV3_ResNet50_Weights(WeightsEnum): + COCO_WITH_VOC_LABELS_V1 = Weights( + url="https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth", + transforms=partial(SemanticSegmentation, resize_size=520), + meta={ + **_COMMON_META, + "num_params": 42004074, + "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_resnet50", + "mIoU": 66.4, + "acc": 92.4, + }, + ) + DEFAULT = COCO_WITH_VOC_LABELS_V1 + + +class DeepLabV3_ResNet101_Weights(WeightsEnum): + COCO_WITH_VOC_LABELS_V1 = Weights( + url="https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth", + transforms=partial(SemanticSegmentation, resize_size=520), + meta={ + **_COMMON_META, + "num_params": 60996202, + "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#fcn_resnet101", + "mIoU": 67.4, + "acc": 92.4, + }, + ) + DEFAULT = COCO_WITH_VOC_LABELS_V1 + + +class DeepLabV3_MobileNet_V3_Large_Weights(WeightsEnum): + COCO_WITH_VOC_LABELS_V1 = Weights( + url="https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth", + transforms=partial(SemanticSegmentation, resize_size=520), + meta={ + **_COMMON_META, + "num_params": 11029328, + "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_mobilenet_v3_large", + "mIoU": 60.3, + "acc": 91.2, + }, + ) + DEFAULT = COCO_WITH_VOC_LABELS_V1 + + def _deeplabv3_mobilenetv3( - backbone: mobilenetv3.MobileNetV3, + backbone: MobileNetV3, num_classes: int, aux: Optional[bool], ) -> DeepLabV3: @@ -151,91 +205,124 @@ def _deeplabv3_mobilenetv3( return DeepLabV3(backbone, classifier, aux_classifier) +@handle_legacy_interface( + 
weights=("pretrained", DeepLabV3_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1), + weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), +) def deeplabv3_resnet50( - pretrained: bool = False, + *, + weights: Optional[DeepLabV3_ResNet50_Weights] = None, progress: bool = True, - num_classes: int = 21, + num_classes: Optional[int] = None, aux_loss: Optional[bool] = None, - pretrained_backbone: bool = True, + weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1, + **kwargs: Any, ) -> DeepLabV3: """Constructs a DeepLabV3 model with a ResNet-50 backbone. Args: - pretrained (bool): If True, returns a model pre-trained on COCO train2017 which - contains the same classes as Pascal VOC + weights (DeepLabV3_ResNet50_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr - num_classes (int): number of output classes of the model (including the background) + num_classes (int, optional): number of output classes of the model (including the background) aux_loss (bool, optional): If True, it uses an auxiliary loss - pretrained_backbone (bool): If True, the backbone will be pre-trained. + weights_backbone (ResNet50_Weights, optional): The pretrained weights for the backbone """ - if pretrained: - aux_loss = True - pretrained_backbone = False + weights = DeepLabV3_ResNet50_Weights.verify(weights) + weights_backbone = ResNet50_Weights.verify(weights_backbone) - backbone = resnet.resnet50(pretrained=pretrained_backbone, replace_stride_with_dilation=[False, True, True]) + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + aux_loss = _ovewrite_value_param(aux_loss, True) + elif num_classes is None: + num_classes = 21 + + backbone = resnet50(weights=weights_backbone, replace_stride_with_dilation=[False, True, True]) model = _deeplabv3_resnet(backbone, num_classes, aux_loss) - if pretrained: - arch = "deeplabv3_resnet50_coco" - _load_weights(arch, model, model_urls.get(arch, None), progress) + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model +@handle_legacy_interface( + weights=("pretrained", DeepLabV3_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1), + weights_backbone=("pretrained_backbone", ResNet101_Weights.IMAGENET1K_V1), +) def deeplabv3_resnet101( - pretrained: bool = False, + *, + weights: Optional[DeepLabV3_ResNet101_Weights] = None, progress: bool = True, - num_classes: int = 21, + num_classes: Optional[int] = None, aux_loss: Optional[bool] = None, - pretrained_backbone: bool = True, + weights_backbone: Optional[ResNet101_Weights] = ResNet101_Weights.IMAGENET1K_V1, + **kwargs: Any, ) -> DeepLabV3: """Constructs a DeepLabV3 model with a ResNet-101 backbone. Args: - pretrained (bool): If True, returns a model pre-trained on COCO train2017 which - contains the same classes as Pascal VOC + weights (DeepLabV3_ResNet101_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr num_classes (int): The number of classes aux_loss (bool, optional): If True, include an auxiliary classifier - pretrained_backbone (bool): If True, the backbone will be pre-trained. 
+ weights_backbone (ResNet101_Weights, optional): The pretrained weights for the backbone """ - if pretrained: - aux_loss = True - pretrained_backbone = False + weights = DeepLabV3_ResNet101_Weights.verify(weights) + weights_backbone = ResNet101_Weights.verify(weights_backbone) - backbone = resnet.resnet101(pretrained=pretrained_backbone, replace_stride_with_dilation=[False, True, True]) + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + aux_loss = _ovewrite_value_param(aux_loss, True) + elif num_classes is None: + num_classes = 21 + + backbone = resnet101(weights=weights_backbone, replace_stride_with_dilation=[False, True, True]) model = _deeplabv3_resnet(backbone, num_classes, aux_loss) - if pretrained: - arch = "deeplabv3_resnet101_coco" - _load_weights(arch, model, model_urls.get(arch, None), progress) + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model +@handle_legacy_interface( + weights=("pretrained", DeepLabV3_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1), + weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1), +) def deeplabv3_mobilenet_v3_large( - pretrained: bool = False, + *, + weights: Optional[DeepLabV3_MobileNet_V3_Large_Weights] = None, progress: bool = True, - num_classes: int = 21, + num_classes: Optional[int] = None, aux_loss: Optional[bool] = None, - pretrained_backbone: bool = True, + weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1, + **kwargs: Any, ) -> DeepLabV3: """Constructs a DeepLabV3 model with a MobileNetV3-Large backbone. Args: - pretrained (bool): If True, returns a model pre-trained on COCO train2017 which - contains the same classes as Pascal VOC + weights (DeepLabV3_MobileNet_V3_Large_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr - num_classes (int): number of output classes of the model (including the background) + num_classes (int, optional): number of output classes of the model (including the background) aux_loss (bool, optional): If True, it uses an auxiliary loss - pretrained_backbone (bool): If True, the backbone will be pre-trained. 
+ weights_backbone (MobileNet_V3_Large_Weights, optional): The pretrained weights for the backbone """ - if pretrained: - aux_loss = True - pretrained_backbone = False + weights = DeepLabV3_MobileNet_V3_Large_Weights.verify(weights) + weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone) - backbone = mobilenetv3.mobilenet_v3_large(pretrained=pretrained_backbone, dilated=True) + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + aux_loss = _ovewrite_value_param(aux_loss, True) + elif num_classes is None: + num_classes = 21 + + backbone = mobilenet_v3_large(weights=weights_backbone, dilated=True) model = _deeplabv3_mobilenetv3(backbone, num_classes, aux_loss) - if pretrained: - arch = "deeplabv3_mobilenet_v3_large_coco" - _load_weights(arch, model, model_urls.get(arch, None), progress) + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model diff --git a/torchvision/models/segmentation/fcn.py b/torchvision/models/segmentation/fcn.py index 307781ebf00..6b6d14ffe32 100644 --- a/torchvision/models/segmentation/fcn.py +++ b/torchvision/models/segmentation/fcn.py @@ -1,19 +1,17 @@ -from typing import Optional +from functools import partial +from typing import Any, Optional from torch import nn -from .. import resnet -from .._utils import IntermediateLayerGetter -from ._utils import _SimpleSegmentationModel, _load_weights +from ...transforms._presets import SemanticSegmentation, InterpolationMode +from .._api import WeightsEnum, Weights +from .._meta import _VOC_CATEGORIES +from .._utils import IntermediateLayerGetter, handle_legacy_interface, _ovewrite_value_param +from ..resnet import ResNet, ResNet50_Weights, ResNet101_Weights, resnet50, resnet101 +from ._utils import _SimpleSegmentationModel -__all__ = ["FCN", "fcn_resnet50", "fcn_resnet101"] - - -model_urls = { - "fcn_resnet50_coco": "https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth", - "fcn_resnet101_coco": "https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth", -} +__all__ = ["FCN", "FCN_ResNet50_Weights", "FCN_ResNet101_Weights", "fcn_resnet50", "fcn_resnet101"] class FCN(_SimpleSegmentationModel): @@ -49,8 +47,47 @@ def __init__(self, in_channels: int, channels: int) -> None: super().__init__(*layers) +_COMMON_META = { + "task": "image_semantic_segmentation", + "architecture": "FCN", + "publication_year": 2014, + "categories": _VOC_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, +} + + +class FCN_ResNet50_Weights(WeightsEnum): + COCO_WITH_VOC_LABELS_V1 = Weights( + url="https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth", + transforms=partial(SemanticSegmentation, resize_size=520), + meta={ + **_COMMON_META, + "num_params": 35322218, + "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#fcn_resnet50", + "mIoU": 60.5, + "acc": 91.4, + }, + ) + DEFAULT = COCO_WITH_VOC_LABELS_V1 + + +class FCN_ResNet101_Weights(WeightsEnum): + COCO_WITH_VOC_LABELS_V1 = Weights( + url="https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth", + transforms=partial(SemanticSegmentation, resize_size=520), + meta={ + **_COMMON_META, + "num_params": 54314346, + "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_resnet101", + "mIoU": 63.7, + "acc": 91.9, + }, + ) + DEFAULT = COCO_WITH_VOC_LABELS_V1 + + def _fcn_resnet( - backbone: 
resnet.ResNet, + backbone: ResNet, num_classes: int, aux: Optional[bool], ) -> FCN: @@ -64,61 +101,83 @@ def _fcn_resnet( return FCN(backbone, classifier, aux_classifier) +@handle_legacy_interface( + weights=("pretrained", FCN_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1), + weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), +) def fcn_resnet50( - pretrained: bool = False, + *, + weights: Optional[FCN_ResNet50_Weights] = None, progress: bool = True, - num_classes: int = 21, + num_classes: Optional[int] = None, aux_loss: Optional[bool] = None, - pretrained_backbone: bool = True, + weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1, + **kwargs: Any, ) -> FCN: """Constructs a Fully-Convolutional Network model with a ResNet-50 backbone. Args: - pretrained (bool): If True, returns a model pre-trained on COCO train2017 which - contains the same classes as Pascal VOC + weights (FCN_ResNet50_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr - num_classes (int): number of output classes of the model (including the background) + num_classes (int, optional): number of output classes of the model (including the background) aux_loss (bool, optional): If True, it uses an auxiliary loss - pretrained_backbone (bool): If True, the backbone will be pre-trained. + weights_backbone (ResNet50_Weights, optional): The pretrained weights for the backbone """ - if pretrained: - aux_loss = True - pretrained_backbone = False + weights = FCN_ResNet50_Weights.verify(weights) + weights_backbone = ResNet50_Weights.verify(weights_backbone) + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + aux_loss = _ovewrite_value_param(aux_loss, True) + elif num_classes is None: + num_classes = 21 - backbone = resnet.resnet50(pretrained=pretrained_backbone, replace_stride_with_dilation=[False, True, True]) + backbone = resnet50(weights=weights_backbone, replace_stride_with_dilation=[False, True, True]) model = _fcn_resnet(backbone, num_classes, aux_loss) - if pretrained: - arch = "fcn_resnet50_coco" - _load_weights(arch, model, model_urls.get(arch, None), progress) + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model +@handle_legacy_interface( + weights=("pretrained", FCN_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1), + weights_backbone=("pretrained_backbone", ResNet101_Weights.IMAGENET1K_V1), +) def fcn_resnet101( - pretrained: bool = False, + *, + weights: Optional[FCN_ResNet101_Weights] = None, progress: bool = True, - num_classes: int = 21, + num_classes: Optional[int] = None, aux_loss: Optional[bool] = None, - pretrained_backbone: bool = True, + weights_backbone: Optional[ResNet101_Weights] = ResNet101_Weights.IMAGENET1K_V1, + **kwargs: Any, ) -> FCN: """Constructs a Fully-Convolutional Network model with a ResNet-101 backbone. 
Args: - pretrained (bool): If True, returns a model pre-trained on COCO train2017 which - contains the same classes as Pascal VOC + weights (FCN_ResNet101_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr - num_classes (int): number of output classes of the model (including the background) + num_classes (int, optional): number of output classes of the model (including the background) aux_loss (bool, optional): If True, it uses an auxiliary loss - pretrained_backbone (bool): If True, the backbone will be pre-trained. + weights_backbone (ResNet101_Weights, optional): The pretrained weights for the backbone """ - if pretrained: - aux_loss = True - pretrained_backbone = False + weights = FCN_ResNet101_Weights.verify(weights) + weights_backbone = ResNet101_Weights.verify(weights_backbone) + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + aux_loss = _ovewrite_value_param(aux_loss, True) + elif num_classes is None: + num_classes = 21 - backbone = resnet.resnet101(pretrained=pretrained_backbone, replace_stride_with_dilation=[False, True, True]) + backbone = resnet101(weights=weights_backbone, replace_stride_with_dilation=[False, True, True]) model = _fcn_resnet(backbone, num_classes, aux_loss) - if pretrained: - arch = "fcn_resnet101_coco" - _load_weights(arch, model, model_urls.get(arch, None), progress) + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model diff --git a/torchvision/models/segmentation/lraspp.py b/torchvision/models/segmentation/lraspp.py index ca73140661b..fc6d14d366b 100644 --- a/torchvision/models/segmentation/lraspp.py +++ b/torchvision/models/segmentation/lraspp.py @@ -1,21 +1,19 @@ from collections import OrderedDict -from typing import Any, Dict +from functools import partial +from typing import Any, Dict, Optional from torch import nn, Tensor from torch.nn import functional as F +from ...transforms._presets import SemanticSegmentation, InterpolationMode from ...utils import _log_api_usage_once -from .. import mobilenetv3 -from .._utils import IntermediateLayerGetter -from ._utils import _load_weights +from .._api import WeightsEnum, Weights +from .._meta import _VOC_CATEGORIES +from .._utils import IntermediateLayerGetter, handle_legacy_interface, _ovewrite_value_param +from ..mobilenetv3 import MobileNetV3, MobileNet_V3_Large_Weights, mobilenet_v3_large -__all__ = ["LRASPP", "lraspp_mobilenet_v3_large"] - - -model_urls = { - "lraspp_mobilenet_v3_large_coco": "https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth", -} +__all__ = ["LRASPP", "LRASPP_MobileNet_V3_Large_Weights", "lraspp_mobilenet_v3_large"] class LRASPP(nn.Module): @@ -30,7 +28,7 @@ class LRASPP(nn.Module): "high" for the high level feature map and "low" for the low level feature map. low_channels (int): the number of channels of the low level features. high_channels (int): the number of channels of the high level features. - num_classes (int): number of output classes of the model (including the background). + num_classes (int, optional): number of output classes of the model (including the background). inter_channels (int, optional): the number of channels for intermediate computations. 
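Because every rewritten builder in this patch is wrapped in @handle_legacy_interface, the old pretrained flag keeps working but is now routed onto the weights argument with a deprecation warning. A small sketch of that behaviour, assuming a torchvision build that already contains this patch (the call downloads the COCO checkpoint):

import warnings
from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    legacy = fcn_resnet50(pretrained=True)   # old-style call, still accepted

# the decorator mapped pretrained=True to weights=COCO_WITH_VOC_LABELS_V1 and warned about it
print(caught[-1].message)

# equivalent new-style call, no warning
new = fcn_resnet50(weights=FCN_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1)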
""" @@ -81,7 +79,7 @@ def forward(self, input: Dict[str, Tensor]) -> Tensor: return self.low_classifier(low) + self.high_classifier(x) -def _lraspp_mobilenetv3(backbone: mobilenetv3.MobileNetV3, num_classes: int) -> LRASPP: +def _lraspp_mobilenetv3(backbone: MobileNetV3, num_classes: int) -> LRASPP: backbone = backbone.features # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks. # The first and last blocks are always included because they are the C0 (conv1) and Cn. @@ -95,31 +93,61 @@ def _lraspp_mobilenetv3(backbone: mobilenetv3.MobileNetV3, num_classes: int) -> return LRASPP(backbone, low_channels, high_channels, num_classes) +class LRASPP_MobileNet_V3_Large_Weights(WeightsEnum): + COCO_WITH_VOC_LABELS_V1 = Weights( + url="https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth", + transforms=partial(SemanticSegmentation, resize_size=520), + meta={ + "task": "image_semantic_segmentation", + "architecture": "LRASPP", + "publication_year": 2019, + "num_params": 3221538, + "categories": _VOC_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#lraspp_mobilenet_v3_large", + "mIoU": 57.9, + "acc": 91.2, + }, + ) + DEFAULT = COCO_WITH_VOC_LABELS_V1 + + +@handle_legacy_interface( + weights=("pretrained", LRASPP_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1), + weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1), +) def lraspp_mobilenet_v3_large( - pretrained: bool = False, + *, + weights: Optional[LRASPP_MobileNet_V3_Large_Weights] = None, progress: bool = True, - num_classes: int = 21, - pretrained_backbone: bool = True, + num_classes: Optional[int] = None, + weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1, **kwargs: Any, ) -> LRASPP: """Constructs a Lite R-ASPP Network model with a MobileNetV3-Large backbone. Args: - pretrained (bool): If True, returns a model pre-trained on COCO train2017 which - contains the same classes as Pascal VOC + weights (LRASPP_MobileNet_V3_Large_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr - num_classes (int): number of output classes of the model (including the background) - pretrained_backbone (bool): If True, the backbone will be pre-trained. 
+ num_classes (int, optional): number of output classes of the model (including the background) + weights_backbone (MobileNet_V3_Large_Weights, optional): The pretrained weights for the backbone """ if kwargs.pop("aux_loss", False): raise NotImplementedError("This model does not use auxiliary loss") - if pretrained: - pretrained_backbone = False - backbone = mobilenetv3.mobilenet_v3_large(pretrained=pretrained_backbone, dilated=True) + weights = LRASPP_MobileNet_V3_Large_Weights.verify(weights) + weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone) + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + elif num_classes is None: + num_classes = 21 + + backbone = mobilenet_v3_large(weights=weights_backbone, dilated=True) model = _lraspp_mobilenetv3(backbone, num_classes) - if pretrained: - arch = "lraspp_mobilenet_v3_large_coco" - _load_weights(arch, model, model_urls.get(arch, None), progress) + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model diff --git a/torchvision/models/shufflenetv2.py b/torchvision/models/shufflenetv2.py index 9a893ba1510..e988b819078 100644 --- a/torchvision/models/shufflenetv2.py +++ b/torchvision/models/shufflenetv2.py @@ -1,21 +1,28 @@ -from typing import Callable, Any, List +from functools import partial +from typing import Callable, Any, List, Optional import torch import torch.nn as nn from torch import Tensor -from .._internally_replaced_utils import load_state_dict_from_url +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once +from ._api import WeightsEnum, Weights +from ._meta import _IMAGENET_CATEGORIES +from ._utils import handle_legacy_interface, _ovewrite_named_param -__all__ = ["ShuffleNetV2", "shufflenet_v2_x0_5", "shufflenet_v2_x1_0", "shufflenet_v2_x1_5", "shufflenet_v2_x2_0"] - -model_urls = { - "shufflenetv2_x0.5": "https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth", - "shufflenetv2_x1.0": "https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth", - "shufflenetv2_x1.5": None, - "shufflenetv2_x2.0": None, -} +__all__ = [ + "ShuffleNetV2", + "ShuffleNet_V2_X0_5_Weights", + "ShuffleNet_V2_X1_0_Weights", + "ShuffleNet_V2_X1_5_Weights", + "ShuffleNet_V2_X2_0_Weights", + "shufflenet_v2_x0_5", + "shufflenet_v2_x1_0", + "shufflenet_v2_x1_5", + "shufflenet_v2_x2_0", +] def channel_shuffle(x: Tensor, groups: int) -> Tensor: @@ -159,67 +166,138 @@ def forward(self, x: Tensor) -> Tensor: return self._forward_impl(x) -def _shufflenetv2(arch: str, pretrained: bool, progress: bool, *args: Any, **kwargs: Any) -> ShuffleNetV2: +def _shufflenetv2( + weights: Optional[WeightsEnum], + progress: bool, + *args: Any, + **kwargs: Any, +) -> ShuffleNetV2: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + model = ShuffleNetV2(*args, **kwargs) - if pretrained: - model_url = model_urls[arch] - if model_url is None: - raise ValueError(f"No checkpoint is available for model type {arch}") - else: - state_dict = load_state_dict_from_url(model_url, progress=progress) - model.load_state_dict(state_dict) + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) return model -def shufflenet_v2_x0_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2: +_COMMON_META = { + "task": "image_classification", + "architecture": 
"ShuffleNetV2", + "publication_year": 2018, + "size": (224, 224), + "min_size": (1, 1), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "recipe": "https://github.com/barrh/Shufflenet-v2-Pytorch/tree/v0.1.0", +} + + +class ShuffleNet_V2_X0_5_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 1366792, + "acc@1": 69.362, + "acc@5": 88.316, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class ShuffleNet_V2_X1_0_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 2278604, + "acc@1": 60.552, + "acc@5": 81.746, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class ShuffleNet_V2_X1_5_Weights(WeightsEnum): + pass + + +class ShuffleNet_V2_X2_0_Weights(WeightsEnum): + pass + + +@handle_legacy_interface(weights=("pretrained", ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1)) +def shufflenet_v2_x0_5( + *, weights: Optional[ShuffleNet_V2_X0_5_Weights] = None, progress: bool = True, **kwargs: Any +) -> ShuffleNetV2: """ Constructs a ShuffleNetV2 with 0.5x output channels, as described in `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ShuffleNet_V2_X0_5_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _shufflenetv2("shufflenetv2_x0.5", pretrained, progress, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs) + weights = ShuffleNet_V2_X0_5_Weights.verify(weights) + return _shufflenetv2(weights, progress, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs) -def shufflenet_v2_x1_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2: + +@handle_legacy_interface(weights=("pretrained", ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1)) +def shufflenet_v2_x1_0( + *, weights: Optional[ShuffleNet_V2_X1_0_Weights] = None, progress: bool = True, **kwargs: Any +) -> ShuffleNetV2: """ Constructs a ShuffleNetV2 with 1.0x output channels, as described in `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ShuffleNet_V2_X1_0_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _shufflenetv2("shufflenetv2_x1.0", pretrained, progress, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs) + weights = ShuffleNet_V2_X1_0_Weights.verify(weights) + + return _shufflenetv2(weights, progress, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs) -def shufflenet_v2_x1_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2: +@handle_legacy_interface(weights=("pretrained", None)) +def shufflenet_v2_x1_5( + *, weights: Optional[ShuffleNet_V2_X1_5_Weights] = None, progress: bool = True, **kwargs: Any +) -> ShuffleNetV2: """ Constructs a ShuffleNetV2 with 1.5x output channels, as described in `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" `_. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ShuffleNet_V2_X1_5_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _shufflenetv2("shufflenetv2_x1.5", pretrained, progress, [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs) + weights = ShuffleNet_V2_X1_5_Weights.verify(weights) + return _shufflenetv2(weights, progress, [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs) -def shufflenet_v2_x2_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2: + +@handle_legacy_interface(weights=("pretrained", None)) +def shufflenet_v2_x2_0( + *, weights: Optional[ShuffleNet_V2_X2_0_Weights] = None, progress: bool = True, **kwargs: Any +) -> ShuffleNetV2: """ Constructs a ShuffleNetV2 with 2.0x output channels, as described in `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ShuffleNet_V2_X2_0_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _shufflenetv2("shufflenetv2_x2.0", pretrained, progress, [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs) + weights = ShuffleNet_V2_X2_0_Weights.verify(weights) + + return _shufflenetv2(weights, progress, [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs) diff --git a/torchvision/models/squeezenet.py b/torchvision/models/squeezenet.py index 2c1a30f225d..bde8b5efcfd 100644 --- a/torchvision/models/squeezenet.py +++ b/torchvision/models/squeezenet.py @@ -1,18 +1,18 @@ -from typing import Any +from functools import partial +from typing import Any, Optional import torch import torch.nn as nn import torch.nn.init as init -from .._internally_replaced_utils import load_state_dict_from_url +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once +from ._api import WeightsEnum, Weights +from ._meta import _IMAGENET_CATEGORIES +from ._utils import handle_legacy_interface, _ovewrite_named_param -__all__ = ["SqueezeNet", "squeezenet1_0", "squeezenet1_1"] -model_urls = { - "squeezenet1_0": "https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth", - "squeezenet1_1": "https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth", -} +__all__ = ["SqueezeNet", "SqueezeNet1_0_Weights", "SqueezeNet1_1_Weights", "squeezenet1_0", "squeezenet1_1"] class Fire(nn.Module): @@ -97,29 +97,85 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.flatten(x, 1) -def _squeezenet(version: str, pretrained: bool, progress: bool, **kwargs: Any) -> SqueezeNet: +def _squeezenet( + version: str, + weights: Optional[WeightsEnum], + progress: bool, + **kwargs: Any, +) -> SqueezeNet: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + model = SqueezeNet(version, **kwargs) - if pretrained: - arch = "squeezenet" + version - state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) - model.load_state_dict(state_dict) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model -def squeezenet1_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> SqueezeNet: +_COMMON_META = { + "task": "image_classification", + "architecture": "SqueezeNet", + "publication_year": 2016, + "size": (224, 224), + "categories": _IMAGENET_CATEGORIES, + 
"interpolation": InterpolationMode.BILINEAR, + "recipe": "https://github.com/pytorch/vision/pull/49#issuecomment-277560717", +} + + +class SqueezeNet1_0_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "min_size": (21, 21), + "num_params": 1248424, + "acc@1": 58.092, + "acc@5": 80.420, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class SqueezeNet1_1_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "min_size": (17, 17), + "num_params": 1235496, + "acc@1": 58.178, + "acc@5": 80.624, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +@handle_legacy_interface(weights=("pretrained", SqueezeNet1_0_Weights.IMAGENET1K_V1)) +def squeezenet1_0( + *, weights: Optional[SqueezeNet1_0_Weights] = None, progress: bool = True, **kwargs: Any +) -> SqueezeNet: r"""SqueezeNet model architecture from the `"SqueezeNet: AlexNet-level accuracy with 50x fewer parameters and <0.5MB model size" `_ paper. The required minimum input size of the model is 21x21. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (SqueezeNet1_0_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _squeezenet("1_0", pretrained, progress, **kwargs) + weights = SqueezeNet1_0_Weights.verify(weights) + return _squeezenet("1_0", weights, progress, **kwargs) -def squeezenet1_1(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> SqueezeNet: +@handle_legacy_interface(weights=("pretrained", SqueezeNet1_1_Weights.IMAGENET1K_V1)) +def squeezenet1_1( + *, weights: Optional[SqueezeNet1_1_Weights] = None, progress: bool = True, **kwargs: Any +) -> SqueezeNet: r"""SqueezeNet 1.1 model from the `official SqueezeNet repo `_. SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters @@ -127,7 +183,8 @@ def squeezenet1_1(pretrained: bool = False, progress: bool = True, **kwargs: Any The required minimum input size of the model is 17x17. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (SqueezeNet1_1_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _squeezenet("1_1", pretrained, progress, **kwargs) + weights = SqueezeNet1_1_Weights.verify(weights) + return _squeezenet("1_1", weights, progress, **kwargs) diff --git a/torchvision/models/vgg.py b/torchvision/models/vgg.py index 07639017a31..c245eef6482 100644 --- a/torchvision/models/vgg.py +++ b/torchvision/models/vgg.py @@ -1,37 +1,37 @@ -from typing import Union, List, Dict, Any, cast +from functools import partial +from typing import Union, List, Dict, Any, Optional, cast import torch import torch.nn as nn -from .._internally_replaced_utils import load_state_dict_from_url +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once +from ._api import WeightsEnum, Weights +from ._meta import _IMAGENET_CATEGORIES +from ._utils import handle_legacy_interface, _ovewrite_named_param __all__ = [ "VGG", + "VGG11_Weights", + "VGG11_BN_Weights", + "VGG13_Weights", + "VGG13_BN_Weights", + "VGG16_Weights", + "VGG16_BN_Weights", + "VGG19_Weights", + "VGG19_BN_Weights", "vgg11", "vgg11_bn", "vgg13", "vgg13_bn", "vgg16", "vgg16_bn", - "vgg19_bn", "vgg19", + "vgg19_bn", ] -model_urls = { - "vgg11": "https://download.pytorch.org/models/vgg11-8a719046.pth", - "vgg13": "https://download.pytorch.org/models/vgg13-19584684.pth", - "vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth", - "vgg19": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth", - "vgg11_bn": "https://download.pytorch.org/models/vgg11_bn-6002323d.pth", - "vgg13_bn": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth", - "vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth", - "vgg19_bn": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth", -} - - class VGG(nn.Module): def __init__( self, features: nn.Module, num_classes: int = 1000, init_weights: bool = True, dropout: float = 0.5 @@ -95,107 +95,276 @@ def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequ } -def _vgg(arch: str, cfg: str, batch_norm: bool, pretrained: bool, progress: bool, **kwargs: Any) -> VGG: - if pretrained: +def _vgg(cfg: str, batch_norm: bool, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any) -> VGG: + if weights is not None: kwargs["init_weights"] = False + if weights.meta["categories"] is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) - if pretrained: - state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) - model.load_state_dict(state_dict) + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) return model -def vgg11(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: +_COMMON_META = { + "task": "image_classification", + "architecture": "VGG", + "publication_year": 2014, + "size": (224, 224), + "min_size": (32, 32), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg", +} + + +class VGG11_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/vgg11-8a719046.pth", + transforms=partial(ImageClassification, 
crop_size=224), + meta={ + **_COMMON_META, + "num_params": 132863336, + "acc@1": 69.020, + "acc@5": 88.628, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class VGG11_BN_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/vgg11_bn-6002323d.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 132868840, + "acc@1": 70.370, + "acc@5": 89.810, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class VGG13_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/vgg13-19584684.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 133047848, + "acc@1": 69.928, + "acc@5": 89.246, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class VGG13_BN_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/vgg13_bn-abd245e5.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 133053736, + "acc@1": 71.586, + "acc@5": 90.374, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class VGG16_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/vgg16-397923af.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 138357544, + "acc@1": 71.592, + "acc@5": 90.382, + }, + ) + # We port the features of a VGG16 backbone trained by amdegroot because unlike the one on TorchVision, it uses the + # same input standardization method as the paper. Only the `features` weights have proper values, those on the + # `classifier` module are filled with nans. + IMAGENET1K_FEATURES = Weights( + url="https://download.pytorch.org/models/vgg16_features-amdegroot-88682ab5.pth", + transforms=partial( + ImageClassification, + crop_size=224, + mean=(0.48235, 0.45882, 0.40784), + std=(1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0), + ), + meta={ + **_COMMON_META, + "num_params": 138357544, + "categories": None, + "recipe": "https://github.com/amdegroot/ssd.pytorch#training-ssd", + "acc@1": float("nan"), + "acc@5": float("nan"), + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class VGG16_BN_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/vgg16_bn-6c64b313.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 138365992, + "acc@1": 73.360, + "acc@5": 91.516, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class VGG19_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/vgg19-dcbb9e9d.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 143667240, + "acc@1": 72.376, + "acc@5": 90.876, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class VGG19_BN_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/vgg19_bn-c79401a0.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 143678248, + "acc@1": 74.218, + "acc@5": 91.842, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +@handle_legacy_interface(weights=("pretrained", VGG11_Weights.IMAGENET1K_V1)) +def vgg11(*, weights: Optional[VGG11_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: r"""VGG 11-layer model (configuration "A") from `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_. The required minimum input size of the model is 32x32. 
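The meta dictionaries introduced above are plain data, so parameter counts and reported accuracies can be read back programmatically from the enums. A small sketch, assuming the patched torchvision where these enums live under torchvision.models:

from torchvision.models import VGG11_Weights, VGG16_Weights, VGG19_BN_Weights

# every enum member exposes its checkpoint URL, preset transforms and metadata
for enum in (VGG11_Weights, VGG16_Weights, VGG19_BN_Weights):
    for w in enum:   # aliases such as DEFAULT are not repeated during iteration
        print(f"{enum.__name__}.{w.name}: "
              f"{w.meta['num_params']:,} params, acc@1={w.meta['acc@1']}")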
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (VGG11_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _vgg("vgg11", "A", False, pretrained, progress, **kwargs) + weights = VGG11_Weights.verify(weights) + return _vgg("A", False, weights, progress, **kwargs) -def vgg11_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: + +@handle_legacy_interface(weights=("pretrained", VGG11_BN_Weights.IMAGENET1K_V1)) +def vgg11_bn(*, weights: Optional[VGG11_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: r"""VGG 11-layer model (configuration "A") with batch normalization `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_. The required minimum input size of the model is 32x32. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (VGG11_BN_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _vgg("vgg11_bn", "A", True, pretrained, progress, **kwargs) + weights = VGG11_BN_Weights.verify(weights) + + return _vgg("A", True, weights, progress, **kwargs) -def vgg13(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: +@handle_legacy_interface(weights=("pretrained", VGG13_Weights.IMAGENET1K_V1)) +def vgg13(*, weights: Optional[VGG13_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: r"""VGG 13-layer model (configuration "B") `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_. The required minimum input size of the model is 32x32. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (VGG13_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _vgg("vgg13", "B", False, pretrained, progress, **kwargs) + weights = VGG13_Weights.verify(weights) + return _vgg("B", False, weights, progress, **kwargs) -def vgg13_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: + +@handle_legacy_interface(weights=("pretrained", VGG13_BN_Weights.IMAGENET1K_V1)) +def vgg13_bn(*, weights: Optional[VGG13_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: r"""VGG 13-layer model (configuration "B") with batch normalization `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_. The required minimum input size of the model is 32x32. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (VGG13_BN_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _vgg("vgg13_bn", "B", True, pretrained, progress, **kwargs) + weights = VGG13_BN_Weights.verify(weights) + + return _vgg("B", True, weights, progress, **kwargs) -def vgg16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: +@handle_legacy_interface(weights=("pretrained", VGG16_Weights.IMAGENET1K_V1)) +def vgg16(*, weights: Optional[VGG16_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: r"""VGG 16-layer model (configuration "D") `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_. The required minimum input size of the model is 32x32. 
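The builders above also show how num_classes interacts with pretrained weights: when a weights enum is passed, the value is filled in from the metadata, and a conflicting explicit value is rejected by _ovewrite_named_param. A hedged sketch of that behaviour (downloads the VGG13 checkpoint):

from torchvision.models import vgg13, VGG13_Weights

weights = VGG13_Weights.IMAGENET1K_V1

model = vgg13(weights=weights, num_classes=1000)   # OK: agrees with the 1000 ImageNet categories
try:
    vgg13(weights=weights, num_classes=10)         # conflicts with the checkpoint metadata
except ValueError as err:
    print(err)

# for a custom head, skip the weights and train from scratch
scratch = vgg13(weights=None, num_classes=10)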
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (VGG16_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _vgg("vgg16", "D", False, pretrained, progress, **kwargs) + weights = VGG16_Weights.verify(weights) + return _vgg("D", False, weights, progress, **kwargs) -def vgg16_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: + +@handle_legacy_interface(weights=("pretrained", VGG16_BN_Weights.IMAGENET1K_V1)) +def vgg16_bn(*, weights: Optional[VGG16_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: r"""VGG 16-layer model (configuration "D") with batch normalization `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_. The required minimum input size of the model is 32x32. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (VGG16_BN_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _vgg("vgg16_bn", "D", True, pretrained, progress, **kwargs) + weights = VGG16_BN_Weights.verify(weights) + + return _vgg("D", True, weights, progress, **kwargs) -def vgg19(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: +@handle_legacy_interface(weights=("pretrained", VGG19_Weights.IMAGENET1K_V1)) +def vgg19(*, weights: Optional[VGG19_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: r"""VGG 19-layer model (configuration "E") `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_. The required minimum input size of the model is 32x32. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (VGG19_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _vgg("vgg19", "E", False, pretrained, progress, **kwargs) + weights = VGG19_Weights.verify(weights) + return _vgg("E", False, weights, progress, **kwargs) -def vgg19_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: + +@handle_legacy_interface(weights=("pretrained", VGG19_BN_Weights.IMAGENET1K_V1)) +def vgg19_bn(*, weights: Optional[VGG19_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: r"""VGG 19-layer model (configuration 'E') with batch normalization `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_. The required minimum input size of the model is 32x32. 
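The VGG16_Weights.IMAGENET1K_FEATURES entry defined earlier is backbone-only: per the comment in the diff, only the `features` tensors carry meaningful values, the classifier is filled with NaNs, and its preset uses the paper-style input standardization. A sketch of using it as a frozen feature extractor, with a zero tensor standing in for a real image:

import torch
from torchvision.models import vgg16, VGG16_Weights

weights = VGG16_Weights.IMAGENET1K_FEATURES
backbone = vgg16(weights=weights).features.eval()   # only `features` carries trained values
preprocess = weights.transforms()                    # custom mean, std of 1/255

with torch.no_grad():
    img = torch.zeros(3, 300, 300, dtype=torch.uint8)    # stand-in for a real image
    feats = backbone(preprocess(img).unsqueeze(0))
print(feats.shape)   # torch.Size([1, 512, 7, 7]) after the 224x224 center crop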
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (VGG19_BN_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _vgg("vgg19_bn", "E", True, pretrained, progress, **kwargs) + weights = VGG19_BN_Weights.verify(weights) + + return _vgg("E", True, weights, progress, **kwargs) diff --git a/torchvision/models/video/resnet.py b/torchvision/models/video/resnet.py index 4ac781a7c4c..618ddb96ba2 100644 --- a/torchvision/models/video/resnet.py +++ b/torchvision/models/video/resnet.py @@ -1,18 +1,25 @@ +from functools import partial from typing import Tuple, Optional, Callable, List, Sequence, Type, Any, Union import torch.nn as nn from torch import Tensor -from ..._internally_replaced_utils import load_state_dict_from_url +from ...transforms._presets import VideoClassification, InterpolationMode from ...utils import _log_api_usage_once +from .._api import WeightsEnum, Weights +from .._meta import _KINETICS400_CATEGORIES +from .._utils import handle_legacy_interface, _ovewrite_named_param -__all__ = ["r3d_18", "mc3_18", "r2plus1d_18"] -model_urls = { - "r3d_18": "https://download.pytorch.org/models/r3d_18-b3b3357e.pth", - "mc3_18": "https://download.pytorch.org/models/mc3_18-a90a0ba3.pth", - "r2plus1d_18": "https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth", -} +__all__ = [ + "VideoResNet", + "R3D_18_Weights", + "MC3_18_Weights", + "R2Plus1D_18_Weights", + "r3d_18", + "mc3_18", + "r2plus1d_18", +] class Conv3DSimple(nn.Conv3d): @@ -281,80 +288,152 @@ def _make_layer( return nn.Sequential(*layers) -def _video_resnet(arch: str, pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VideoResNet: - model = VideoResNet(**kwargs) +def _video_resnet( + block: Type[Union[BasicBlock, Bottleneck]], + conv_makers: Sequence[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]], + layers: List[int], + stem: Callable[..., nn.Module], + weights: Optional[WeightsEnum], + progress: bool, + **kwargs: Any, +) -> VideoResNet: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + + model = VideoResNet(block, conv_makers, layers, stem, **kwargs) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) - if pretrained: - state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) - model.load_state_dict(state_dict) return model -def r3d_18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VideoResNet: +_COMMON_META = { + "task": "video_classification", + "publication_year": 2017, + "size": (112, 112), + "min_size": (1, 1), + "categories": _KINETICS400_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "recipe": "https://github.com/pytorch/vision/tree/main/references/video_classification", +} + + +class R3D_18_Weights(WeightsEnum): + KINETICS400_V1 = Weights( + url="https://download.pytorch.org/models/r3d_18-b3b3357e.pth", + transforms=partial(VideoClassification, crop_size=(112, 112), resize_size=(128, 171)), + meta={ + **_COMMON_META, + "architecture": "R3D", + "num_params": 33371472, + "acc@1": 52.75, + "acc@5": 75.45, + }, + ) + DEFAULT = KINETICS400_V1 + + +class MC3_18_Weights(WeightsEnum): + KINETICS400_V1 = Weights( + url="https://download.pytorch.org/models/mc3_18-a90a0ba3.pth", + transforms=partial(VideoClassification, crop_size=(112, 112), resize_size=(128, 171)), + meta={ + **_COMMON_META, + "architecture": 
"MC3", + "num_params": 11695440, + "acc@1": 53.90, + "acc@5": 76.29, + }, + ) + DEFAULT = KINETICS400_V1 + + +class R2Plus1D_18_Weights(WeightsEnum): + KINETICS400_V1 = Weights( + url="https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth", + transforms=partial(VideoClassification, crop_size=(112, 112), resize_size=(128, 171)), + meta={ + **_COMMON_META, + "architecture": "R(2+1)D", + "num_params": 31505325, + "acc@1": 57.50, + "acc@5": 78.81, + }, + ) + DEFAULT = KINETICS400_V1 + + +@handle_legacy_interface(weights=("pretrained", R3D_18_Weights.KINETICS400_V1)) +def r3d_18(*, weights: Optional[R3D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: """Construct 18 layer Resnet3D model as in https://arxiv.org/abs/1711.11248 Args: - pretrained (bool): If True, returns a model pre-trained on Kinetics-400 + weights (R3D_18_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr Returns: - nn.Module: R3D-18 network + VideoResNet: R3D-18 network """ + weights = R3D_18_Weights.verify(weights) return _video_resnet( - "r3d_18", - pretrained, + BasicBlock, + [Conv3DSimple] * 4, + [2, 2, 2, 2], + BasicStem, + weights, progress, - block=BasicBlock, - conv_makers=[Conv3DSimple] * 4, - layers=[2, 2, 2, 2], - stem=BasicStem, **kwargs, ) -def mc3_18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VideoResNet: +@handle_legacy_interface(weights=("pretrained", MC3_18_Weights.KINETICS400_V1)) +def mc3_18(*, weights: Optional[MC3_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: """Constructor for 18 layer Mixed Convolution network as in https://arxiv.org/abs/1711.11248 Args: - pretrained (bool): If True, returns a model pre-trained on Kinetics-400 + weights (MC3_18_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr Returns: - nn.Module: MC3 Network definition + VideoResNet: MC3 Network definition """ + weights = MC3_18_Weights.verify(weights) + return _video_resnet( - "mc3_18", - pretrained, + BasicBlock, + [Conv3DSimple] + [Conv3DNoTemporal] * 3, # type: ignore[list-item] + [2, 2, 2, 2], + BasicStem, + weights, progress, - block=BasicBlock, - conv_makers=[Conv3DSimple] + [Conv3DNoTemporal] * 3, # type: ignore[list-item] - layers=[2, 2, 2, 2], - stem=BasicStem, **kwargs, ) -def r2plus1d_18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VideoResNet: +@handle_legacy_interface(weights=("pretrained", R2Plus1D_18_Weights.KINETICS400_V1)) +def r2plus1d_18(*, weights: Optional[R2Plus1D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: """Constructor for the 18 layer deep R(2+1)D network as in https://arxiv.org/abs/1711.11248 Args: - pretrained (bool): If True, returns a model pre-trained on Kinetics-400 + weights (R2Plus1D_18_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr Returns: - nn.Module: R(2+1)D-18 network + VideoResNet: R(2+1)D-18 network """ + weights = R2Plus1D_18_Weights.verify(weights) + return _video_resnet( - "r2plus1d_18", - pretrained, + BasicBlock, + [Conv2Plus1D] * 4, + [2, 2, 2, 2], + R2Plus1dStem, + weights, progress, - block=BasicBlock, - conv_makers=[Conv2Plus1D] * 4, - layers=[2, 2, 2, 2], - stem=R2Plus1dStem, **kwargs, ) diff --git a/torchvision/models/vision_transformer.py b/torchvision/models/vision_transformer.py index 
43e4d315cec..fb34cf3c8e1 100644 --- a/torchvision/models/vision_transformer.py +++ b/torchvision/models/vision_transformer.py @@ -6,25 +6,26 @@ import torch import torch.nn as nn -from .._internally_replaced_utils import load_state_dict_from_url from ..ops.misc import Conv2dNormActivation +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once +from ._api import WeightsEnum, Weights +from ._meta import _IMAGENET_CATEGORIES +from ._utils import handle_legacy_interface, _ovewrite_named_param + __all__ = [ "VisionTransformer", + "ViT_B_16_Weights", + "ViT_B_32_Weights", + "ViT_L_16_Weights", + "ViT_L_32_Weights", "vit_b_16", "vit_b_32", "vit_l_16", "vit_l_32", ] -model_urls = { - "vit_b_16": "https://download.pytorch.org/models/vit_b_16-c867db91.pth", - "vit_b_32": "https://download.pytorch.org/models/vit_b_32-d86f8d99.pth", - "vit_l_16": "https://download.pytorch.org/models/vit_l_16-852ce7e3.pth", - "vit_l_32": "https://download.pytorch.org/models/vit_l_32-c7638314.pth", -} - class ConvStemConfig(NamedTuple): out_channels: int @@ -274,18 +275,20 @@ def forward(self, x: torch.Tensor): def _vision_transformer( - arch: str, patch_size: int, num_layers: int, num_heads: int, hidden_dim: int, mlp_dim: int, - pretrained: bool, + weights: Optional[WeightsEnum], progress: bool, **kwargs: Any, ) -> VisionTransformer: image_size = kwargs.pop("image_size", 224) + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + model = VisionTransformer( image_size=image_size, patch_size=patch_size, @@ -296,98 +299,180 @@ def _vision_transformer( **kwargs, ) - if pretrained: - if arch not in model_urls: - raise ValueError(f"No checkpoint is available for model type '{arch}'!") - state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) - model.load_state_dict(state_dict) + if weights: + model.load_state_dict(weights.get_state_dict(progress=progress)) return model -def vit_b_16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer: +_COMMON_META = { + "task": "image_classification", + "architecture": "ViT", + "publication_year": 2020, + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, +} + + +class ViT_B_16_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/vit_b_16-c867db91.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 86567656, + "size": (224, 224), + "min_size": (224, 224), + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_b_16", + "acc@1": 81.072, + "acc@5": 95.318, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class ViT_B_32_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/vit_b_32-d86f8d99.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 88224232, + "size": (224, 224), + "min_size": (224, 224), + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_b_32", + "acc@1": 75.912, + "acc@5": 92.466, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class ViT_L_16_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/vit_l_16-852ce7e3.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=242), + meta={ + **_COMMON_META, + "num_params": 304326632, + "size": (224, 224), + "min_size": 
(224, 224), + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_l_16", + "acc@1": 79.662, + "acc@5": 94.638, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class ViT_L_32_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/vit_l_32-c7638314.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 306535400, + "size": (224, 224), + "min_size": (224, 224), + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_l_32", + "acc@1": 76.972, + "acc@5": 93.07, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +@handle_legacy_interface(weights=("pretrained", ViT_B_16_Weights.IMAGENET1K_V1)) +def vit_b_16(*, weights: Optional[ViT_B_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: """ Constructs a vit_b_16 architecture from `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ViT_B_16_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = ViT_B_16_Weights.verify(weights) + return _vision_transformer( - arch="vit_b_16", patch_size=16, num_layers=12, num_heads=12, hidden_dim=768, mlp_dim=3072, - pretrained=pretrained, + weights=weights, progress=progress, **kwargs, ) -def vit_b_32(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer: +@handle_legacy_interface(weights=("pretrained", ViT_B_32_Weights.IMAGENET1K_V1)) +def vit_b_32(*, weights: Optional[ViT_B_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: """ Constructs a vit_b_32 architecture from `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ViT_B_32_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = ViT_B_32_Weights.verify(weights) + return _vision_transformer( - arch="vit_b_32", patch_size=32, num_layers=12, num_heads=12, hidden_dim=768, mlp_dim=3072, - pretrained=pretrained, + weights=weights, progress=progress, **kwargs, ) -def vit_l_16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer: +@handle_legacy_interface(weights=("pretrained", ViT_L_16_Weights.IMAGENET1K_V1)) +def vit_l_16(*, weights: Optional[ViT_L_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: """ Constructs a vit_l_16 architecture from `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" `_. 
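As with the other families, the ViT builders now take a weights enum; the image_size keyword is still popped inside _vision_transformer, so it can only differ from 224 when training from scratch. A brief sketch under those assumptions:

from torchvision.models import vit_b_16, ViT_B_16_Weights

# pretrained ViT-B/16; the bundled preset resizes and crops to the 224x224 the weights expect
weights = ViT_B_16_Weights.IMAGENET1K_V1
model = vit_b_16(weights=weights)
preprocess = weights.transforms()

# from scratch at a smaller resolution (must be divisible by the 16-pixel patch size)
small = vit_b_16(weights=None, image_size=112, num_classes=10)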
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ViT_L_16_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = ViT_L_16_Weights.verify(weights) + return _vision_transformer( - arch="vit_l_16", patch_size=16, num_layers=24, num_heads=16, hidden_dim=1024, mlp_dim=4096, - pretrained=pretrained, + weights=weights, progress=progress, **kwargs, ) -def vit_l_32(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer: +@handle_legacy_interface(weights=("pretrained", ViT_L_32_Weights.IMAGENET1K_V1)) +def vit_l_32(*, weights: Optional[ViT_L_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: """ Constructs a vit_l_32 architecture from `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ViT_L_32_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = ViT_L_32_Weights.verify(weights) + return _vision_transformer( - arch="vit_l_32", patch_size=32, num_layers=24, num_heads=16, hidden_dim=1024, mlp_dim=4096, - pretrained=pretrained, + weights=weights, progress=progress, **kwargs, ) diff --git a/torchvision/prototype/__init__.py b/torchvision/prototype/__init__.py index e1be6c81f59..bd35d31dcfd 100644 --- a/torchvision/prototype/__init__.py +++ b/torchvision/prototype/__init__.py @@ -1,5 +1,4 @@ from . import datasets from . import features -from . import models from . import transforms from . import utils diff --git a/torchvision/prototype/models/__init__.py b/torchvision/prototype/models/__init__.py deleted file mode 100644 index 83e49908348..00000000000 --- a/torchvision/prototype/models/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -from .alexnet import * -from .convnext import * -from .densenet import * -from .efficientnet import * -from .googlenet import * -from .inception import * -from .mnasnet import * -from .mobilenet import * -from .regnet import * -from .resnet import * -from .shufflenetv2 import * -from .squeezenet import * -from .vgg import * -from .vision_transformer import * -from . import detection -from . import optical_flow -from . import quantization -from . import segmentation -from . import video -from ._api import get_weight diff --git a/torchvision/prototype/models/_utils.py b/torchvision/prototype/models/_utils.py deleted file mode 100644 index cc9f7dcfc36..00000000000 --- a/torchvision/prototype/models/_utils.py +++ /dev/null @@ -1,108 +0,0 @@ -import functools -import warnings -from typing import Any, Dict, Optional, TypeVar, Callable, Tuple, Union - -from torch import nn -from torchvision.prototype.utils._internal import kwonly_to_pos_or_kw - -from ._api import WeightsEnum - -W = TypeVar("W", bound=WeightsEnum) -M = TypeVar("M", bound=nn.Module) -V = TypeVar("V") - - -def handle_legacy_interface(**weights: Tuple[str, Union[Optional[W], Callable[[Dict[str, Any]], Optional[W]]]]): - """Decorates a model builder with the new interface to make it compatible with the old. - - In particular this handles two things: - - 1. Allows positional parameters again, but emits a deprecation warning in case they are used. See - :func:`torchvision.prototype.utils._internal.kwonly_to_pos_or_kw` for details. - 2. 
Handles the default value change from ``pretrained=False`` to ``weights=None`` and ``pretrained=True`` to - ``weights=Weights`` and emits a deprecation warning with instructions for the new interface. - - Args: - **weights (Tuple[str, Union[Optional[W], Callable[[Dict[str, Any]], Optional[W]]]]): Deprecated parameter - name and default value for the legacy ``pretrained=True``. The default value can be a callable in which - case it will be called with a dictionary of the keyword arguments. The only key that is guaranteed to be in - the dictionary is the deprecated parameter name passed as first element in the tuple. All other parameters - should be accessed with :meth:`~dict.get`. - """ - - def outer_wrapper(builder: Callable[..., M]) -> Callable[..., M]: - @kwonly_to_pos_or_kw - @functools.wraps(builder) - def inner_wrapper(*args: Any, **kwargs: Any) -> M: - for weights_param, (pretrained_param, default) in weights.items(): # type: ignore[union-attr] - # If neither the weights nor the pretrained parameter as passed, or the weights argument already use - # the new style arguments, there is nothing to do. Note that we cannot use `None` as sentinel for the - # weight argument, since it is a valid value. - sentinel = object() - weights_arg = kwargs.get(weights_param, sentinel) - if ( - (weights_param not in kwargs and pretrained_param not in kwargs) - or isinstance(weights_arg, WeightsEnum) - or (isinstance(weights_arg, str) and weights_arg != "legacy") - or weights_arg is None - ): - continue - - # If the pretrained parameter was passed as positional argument, it is now mapped to - # `kwargs[weights_param]`. This happens because the @kwonly_to_pos_or_kw decorator uses the current - # signature to infer the names of positionally passed arguments and thus has no knowledge that there - # used to be a pretrained parameter. - pretrained_positional = weights_arg is not sentinel - if pretrained_positional: - # We put the pretrained argument under its legacy name in the keyword argument dictionary to have a - # unified access to the value if the default value is a callable. - kwargs[pretrained_param] = pretrained_arg = kwargs.pop(weights_param) - else: - pretrained_arg = kwargs[pretrained_param] - - if pretrained_arg: - default_weights_arg = default(kwargs) if callable(default) else default - if not isinstance(default_weights_arg, WeightsEnum): - raise ValueError(f"No weights available for model {builder.__name__}") - else: - default_weights_arg = None - - if not pretrained_positional: - warnings.warn( - f"The parameter '{pretrained_param}' is deprecated, please use '{weights_param}' instead." - ) - - msg = ( - f"Arguments other than a weight enum or `None` for '{weights_param}' are deprecated. " - f"The current behavior is equivalent to passing `{weights_param}={default_weights_arg}`." - ) - if pretrained_arg: - msg = ( - f"{msg} You can also use `{weights_param}={type(default_weights_arg).__name__}.DEFAULT` " - f"to get the most up-to-date weights." 
- ) - warnings.warn(msg) - - del kwargs[pretrained_param] - kwargs[weights_param] = default_weights_arg - - return builder(*args, **kwargs) - - return inner_wrapper - - return outer_wrapper - - -def _ovewrite_named_param(kwargs: Dict[str, Any], param: str, new_value: V) -> None: - if param in kwargs: - if kwargs[param] != new_value: - raise ValueError(f"The parameter '{param}' expected value {new_value} but got {kwargs[param]} instead.") - else: - kwargs[param] = new_value - - -def _ovewrite_value_param(param: Optional[V], new_value: V) -> V: - if param is not None: - if param != new_value: - raise ValueError(f"The parameter '{param}' expected value {new_value} but got {param} instead.") - return new_value diff --git a/torchvision/prototype/models/alexnet.py b/torchvision/prototype/models/alexnet.py deleted file mode 100644 index 204a68236d3..00000000000 --- a/torchvision/prototype/models/alexnet.py +++ /dev/null @@ -1,49 +0,0 @@ -from functools import partial -from typing import Any, Optional - -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ...models.alexnet import AlexNet -from ._api import WeightsEnum, Weights -from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param - - -__all__ = ["AlexNet", "AlexNet_Weights", "alexnet"] - - -class AlexNet_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/alexnet-owt-7be5be79.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - "task": "image_classification", - "architecture": "AlexNet", - "publication_year": 2012, - "num_params": 61100840, - "size": (224, 224), - "min_size": (63, 63), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg", - "acc@1": 56.522, - "acc@5": 79.066, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -@handle_legacy_interface(weights=("pretrained", AlexNet_Weights.IMAGENET1K_V1)) -def alexnet(*, weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet: - weights = AlexNet_Weights.verify(weights) - - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - - model = AlexNet(**kwargs) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model diff --git a/torchvision/prototype/models/convnext.py b/torchvision/prototype/models/convnext.py deleted file mode 100644 index 7d63ee155db..00000000000 --- a/torchvision/prototype/models/convnext.py +++ /dev/null @@ -1,169 +0,0 @@ -from functools import partial -from typing import Any, List, Optional - -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ...models.convnext import ConvNeXt, CNBlockConfig -from ._api import WeightsEnum, Weights -from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param - - -__all__ = [ - "ConvNeXt", - "ConvNeXt_Tiny_Weights", - "ConvNeXt_Small_Weights", - "ConvNeXt_Base_Weights", - "ConvNeXt_Large_Weights", - "convnext_tiny", - "convnext_small", - "convnext_base", - "convnext_large", -] - - -def _convnext( - block_setting: List[CNBlockConfig], - stochastic_depth_prob: float, - weights: Optional[WeightsEnum], - progress: bool, - **kwargs: 
Any, -) -> ConvNeXt: - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - - model = ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model - - -_COMMON_META = { - "task": "image_classification", - "architecture": "ConvNeXt", - "publication_year": 2022, - "size": (224, 224), - "min_size": (32, 32), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#convnext", -} - - -class ConvNeXt_Tiny_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/convnext_tiny-983f1562.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=236), - meta={ - **_COMMON_META, - "num_params": 28589128, - "acc@1": 82.520, - "acc@5": 96.146, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class ConvNeXt_Small_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/convnext_small-0c510722.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=230), - meta={ - **_COMMON_META, - "num_params": 50223688, - "acc@1": 83.616, - "acc@5": 96.650, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class ConvNeXt_Base_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/convnext_base-6075fbad.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "num_params": 88591464, - "acc@1": 84.062, - "acc@5": 96.870, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class ConvNeXt_Large_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/convnext_large-ea097f82.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "num_params": 197767336, - "acc@1": 84.414, - "acc@5": 96.976, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -@handle_legacy_interface(weights=("pretrained", ConvNeXt_Tiny_Weights.IMAGENET1K_V1)) -def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt: - weights = ConvNeXt_Tiny_Weights.verify(weights) - - block_setting = [ - CNBlockConfig(96, 192, 3), - CNBlockConfig(192, 384, 3), - CNBlockConfig(384, 768, 9), - CNBlockConfig(768, None, 3), - ] - stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.1) - return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", ConvNeXt_Small_Weights.IMAGENET1K_V1)) -def convnext_small( - *, weights: Optional[ConvNeXt_Small_Weights] = None, progress: bool = True, **kwargs: Any -) -> ConvNeXt: - weights = ConvNeXt_Small_Weights.verify(weights) - - block_setting = [ - CNBlockConfig(96, 192, 3), - CNBlockConfig(192, 384, 3), - CNBlockConfig(384, 768, 27), - CNBlockConfig(768, None, 3), - ] - stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.4) - return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", ConvNeXt_Base_Weights.IMAGENET1K_V1)) -def convnext_base(*, weights: Optional[ConvNeXt_Base_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt: - weights = ConvNeXt_Base_Weights.verify(weights) - - block_setting = [ - CNBlockConfig(128, 256, 3), - 
CNBlockConfig(256, 512, 3), - CNBlockConfig(512, 1024, 27), - CNBlockConfig(1024, None, 3), - ] - stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5) - return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", ConvNeXt_Large_Weights.IMAGENET1K_V1)) -def convnext_large( - *, weights: Optional[ConvNeXt_Large_Weights] = None, progress: bool = True, **kwargs: Any -) -> ConvNeXt: - weights = ConvNeXt_Large_Weights.verify(weights) - - block_setting = [ - CNBlockConfig(192, 384, 3), - CNBlockConfig(384, 768, 3), - CNBlockConfig(768, 1536, 27), - CNBlockConfig(1536, None, 3), - ] - stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5) - return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs) diff --git a/torchvision/prototype/models/densenet.py b/torchvision/prototype/models/densenet.py deleted file mode 100644 index 4ad9be028e5..00000000000 --- a/torchvision/prototype/models/densenet.py +++ /dev/null @@ -1,159 +0,0 @@ -import re -from functools import partial -from typing import Any, Optional, Tuple - -import torch.nn as nn -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ...models.densenet import DenseNet -from ._api import WeightsEnum, Weights -from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param - - -__all__ = [ - "DenseNet", - "DenseNet121_Weights", - "DenseNet161_Weights", - "DenseNet169_Weights", - "DenseNet201_Weights", - "densenet121", - "densenet161", - "densenet169", - "densenet201", -] - - -def _load_state_dict(model: nn.Module, weights: WeightsEnum, progress: bool) -> None: - # '.'s are no longer allowed in module names, but previous _DenseLayer - # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. - # They are also in the checkpoints in model_urls. This pattern is used - # to find such keys. 
- pattern = re.compile( - r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$" - ) - - state_dict = weights.get_state_dict(progress=progress) - for key in list(state_dict.keys()): - res = pattern.match(key) - if res: - new_key = res.group(1) + res.group(2) - state_dict[new_key] = state_dict[key] - del state_dict[key] - model.load_state_dict(state_dict) - - -def _densenet( - growth_rate: int, - block_config: Tuple[int, int, int, int], - num_init_features: int, - weights: Optional[WeightsEnum], - progress: bool, - **kwargs: Any, -) -> DenseNet: - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - - model = DenseNet(growth_rate, block_config, num_init_features, **kwargs) - - if weights is not None: - _load_state_dict(model=model, weights=weights, progress=progress) - - return model - - -_COMMON_META = { - "task": "image_classification", - "architecture": "DenseNet", - "publication_year": 2016, - "size": (224, 224), - "min_size": (29, 29), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "recipe": "https://github.com/pytorch/vision/pull/116", -} - - -class DenseNet121_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/densenet121-a639ec97.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 7978856, - "acc@1": 74.434, - "acc@5": 91.972, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class DenseNet161_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/densenet161-8d451a50.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 28681000, - "acc@1": 77.138, - "acc@5": 93.560, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class DenseNet169_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/densenet169-b2777c0a.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 14149480, - "acc@1": 75.600, - "acc@5": 92.806, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class DenseNet201_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/densenet201-c1103571.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 20013928, - "acc@1": 76.896, - "acc@5": 93.370, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -@handle_legacy_interface(weights=("pretrained", DenseNet121_Weights.IMAGENET1K_V1)) -def densenet121(*, weights: Optional[DenseNet121_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: - weights = DenseNet121_Weights.verify(weights) - - return _densenet(32, (6, 12, 24, 16), 64, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", DenseNet161_Weights.IMAGENET1K_V1)) -def densenet161(*, weights: Optional[DenseNet161_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: - weights = DenseNet161_Weights.verify(weights) - - return _densenet(48, (6, 12, 36, 24), 96, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", DenseNet169_Weights.IMAGENET1K_V1)) -def densenet169(*, weights: Optional[DenseNet169_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: - weights = DenseNet169_Weights.verify(weights) - - return _densenet(32, (6, 12, 32, 32), 64, weights, progress, **kwargs) - - 
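For reference, the regular expression in _load_state_dict above is what maps legacy DenseNet checkpoint keys (which still contain a dot before the layer index, e.g. "norm.1") onto the dot-free names used by the current _DenseLayer modules. A minimal sketch of the rewrite, using a hypothetical key purely for illustration:

import re

pattern = re.compile(
    r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
)

key = "features.denseblock1.denselayer1.norm.1.weight"  # hypothetical legacy checkpoint key
match = pattern.match(key)
new_key = match.group(1) + match.group(2)  # concatenating the groups drops the '.' between 'norm' and '1'
assert new_key == "features.denseblock1.denselayer1.norm1.weight"

The old key is then deleted from the state dict and the remapped one inserted before load_state_dict is called, as in the hunk above.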
-@handle_legacy_interface(weights=("pretrained", DenseNet201_Weights.IMAGENET1K_V1)) -def densenet201(*, weights: Optional[DenseNet201_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: - weights = DenseNet201_Weights.verify(weights) - - return _densenet(32, (6, 12, 48, 32), 64, weights, progress, **kwargs) diff --git a/torchvision/prototype/models/detection/__init__.py b/torchvision/prototype/models/detection/__init__.py deleted file mode 100644 index 4146651c737..00000000000 --- a/torchvision/prototype/models/detection/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from .faster_rcnn import * -from .fcos import * -from .keypoint_rcnn import * -from .mask_rcnn import * -from .retinanet import * -from .ssd import * -from .ssdlite import * diff --git a/torchvision/prototype/models/detection/faster_rcnn.py b/torchvision/prototype/models/detection/faster_rcnn.py deleted file mode 100644 index ecdd9bdb423..00000000000 --- a/torchvision/prototype/models/detection/faster_rcnn.py +++ /dev/null @@ -1,228 +0,0 @@ -from typing import Any, Optional, Union - -from torch import nn -from torchvision.prototype.transforms import ObjectDetectionEval -from torchvision.transforms.functional import InterpolationMode - -from ....models.detection.faster_rcnn import ( - _mobilenet_extractor, - _resnet_fpn_extractor, - _validate_trainable_layers, - AnchorGenerator, - FasterRCNN, - misc_nn_ops, - overwrite_eps, -) -from .._api import WeightsEnum, Weights -from .._meta import _COCO_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_value_param -from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large -from ..resnet import ResNet50_Weights, resnet50 - - -__all__ = [ - "FasterRCNN", - "FasterRCNN_ResNet50_FPN_Weights", - "FasterRCNN_MobileNet_V3_Large_FPN_Weights", - "FasterRCNN_MobileNet_V3_Large_320_FPN_Weights", - "fasterrcnn_resnet50_fpn", - "fasterrcnn_mobilenet_v3_large_fpn", - "fasterrcnn_mobilenet_v3_large_320_fpn", -] - - -_COMMON_META = { - "task": "image_object_detection", - "architecture": "FasterRCNN", - "publication_year": 2015, - "categories": _COCO_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, -} - - -class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum): - COCO_V1 = Weights( - url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth", - transforms=ObjectDetectionEval, - meta={ - **_COMMON_META, - "num_params": 41755286, - "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn", - "map": 37.0, - }, - ) - DEFAULT = COCO_V1 - - -class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum): - COCO_V1 = Weights( - url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth", - transforms=ObjectDetectionEval, - meta={ - **_COMMON_META, - "num_params": 19386354, - "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn", - "map": 32.8, - }, - ) - DEFAULT = COCO_V1 - - -class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum): - COCO_V1 = Weights( - url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth", - transforms=ObjectDetectionEval, - meta={ - **_COMMON_META, - "num_params": 19386354, - "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn", - "map": 22.8, - }, - ) - DEFAULT = COCO_V1 - - -@handle_legacy_interface( - 
weights=("pretrained", FasterRCNN_ResNet50_FPN_Weights.COCO_V1), - weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), -) -def fasterrcnn_resnet50_fpn( - *, - weights: Optional[FasterRCNN_ResNet50_FPN_Weights] = None, - progress: bool = True, - num_classes: Optional[int] = None, - weights_backbone: Optional[ResNet50_Weights] = None, - trainable_backbone_layers: Optional[int] = None, - **kwargs: Any, -) -> FasterRCNN: - weights = FasterRCNN_ResNet50_FPN_Weights.verify(weights) - weights_backbone = ResNet50_Weights.verify(weights_backbone) - - if weights is not None: - weights_backbone = None - num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) - elif num_classes is None: - num_classes = 91 - - is_trained = weights is not None or weights_backbone is not None - trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3) - norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d - - backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer) - backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers) - model = FasterRCNN(backbone, num_classes=num_classes, **kwargs) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - if weights == FasterRCNN_ResNet50_FPN_Weights.COCO_V1: - overwrite_eps(model, 0.0) - - return model - - -def _fasterrcnn_mobilenet_v3_large_fpn( - *, - weights: Optional[Union[FasterRCNN_MobileNet_V3_Large_FPN_Weights, FasterRCNN_MobileNet_V3_Large_320_FPN_Weights]], - progress: bool, - num_classes: Optional[int], - weights_backbone: Optional[MobileNet_V3_Large_Weights], - trainable_backbone_layers: Optional[int], - **kwargs: Any, -) -> FasterRCNN: - if weights is not None: - weights_backbone = None - num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) - elif num_classes is None: - num_classes = 91 - - is_trained = weights is not None or weights_backbone is not None - trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 6, 3) - norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d - - backbone = mobilenet_v3_large(weights=weights_backbone, progress=progress, norm_layer=norm_layer) - backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers) - anchor_sizes = ( - ( - 32, - 64, - 128, - 256, - 512, - ), - ) * 3 - aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes) - model = FasterRCNN( - backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios), **kwargs - ) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model - - -@handle_legacy_interface( - weights=("pretrained", FasterRCNN_MobileNet_V3_Large_FPN_Weights.COCO_V1), - weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1), -) -def fasterrcnn_mobilenet_v3_large_fpn( - *, - weights: Optional[FasterRCNN_MobileNet_V3_Large_FPN_Weights] = None, - progress: bool = True, - num_classes: Optional[int] = None, - weights_backbone: Optional[MobileNet_V3_Large_Weights] = None, - trainable_backbone_layers: Optional[int] = None, - **kwargs: Any, -) -> FasterRCNN: - weights = FasterRCNN_MobileNet_V3_Large_FPN_Weights.verify(weights) - weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone) - - defaults = { - "rpn_score_thresh": 0.05, - } - - kwargs = {**defaults, **kwargs} - return 
_fasterrcnn_mobilenet_v3_large_fpn( - weights=weights, - progress=progress, - num_classes=num_classes, - weights_backbone=weights_backbone, - trainable_backbone_layers=trainable_backbone_layers, - **kwargs, - ) - - -@handle_legacy_interface( - weights=("pretrained", FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.COCO_V1), - weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1), -) -def fasterrcnn_mobilenet_v3_large_320_fpn( - *, - weights: Optional[FasterRCNN_MobileNet_V3_Large_320_FPN_Weights] = None, - progress: bool = True, - num_classes: Optional[int] = None, - weights_backbone: Optional[MobileNet_V3_Large_Weights] = None, - trainable_backbone_layers: Optional[int] = None, - **kwargs: Any, -) -> FasterRCNN: - - weights = FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.verify(weights) - weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone) - - defaults = { - "min_size": 320, - "max_size": 640, - "rpn_pre_nms_top_n_test": 150, - "rpn_post_nms_top_n_test": 150, - "rpn_score_thresh": 0.05, - } - - kwargs = {**defaults, **kwargs} - return _fasterrcnn_mobilenet_v3_large_fpn( - weights=weights, - progress=progress, - num_classes=num_classes, - weights_backbone=weights_backbone, - trainable_backbone_layers=trainable_backbone_layers, - **kwargs, - ) diff --git a/torchvision/prototype/models/detection/fcos.py b/torchvision/prototype/models/detection/fcos.py deleted file mode 100644 index db3a679a62d..00000000000 --- a/torchvision/prototype/models/detection/fcos.py +++ /dev/null @@ -1,80 +0,0 @@ -from typing import Any, Optional - -from torch import nn -from torchvision.prototype.transforms import ObjectDetectionEval -from torchvision.transforms.functional import InterpolationMode - -from ....models.detection.fcos import ( - _resnet_fpn_extractor, - _validate_trainable_layers, - FCOS, - LastLevelP6P7, - misc_nn_ops, -) -from .._api import WeightsEnum, Weights -from .._meta import _COCO_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_value_param -from ..resnet import ResNet50_Weights, resnet50 - - -__all__ = [ - "FCOS", - "FCOS_ResNet50_FPN_Weights", - "fcos_resnet50_fpn", -] - - -class FCOS_ResNet50_FPN_Weights(WeightsEnum): - COCO_V1 = Weights( - url="https://download.pytorch.org/models/fcos_resnet50_fpn_coco-99b0c9b7.pth", - transforms=ObjectDetectionEval, - meta={ - "task": "image_object_detection", - "architecture": "FCOS", - "publication_year": 2019, - "num_params": 32269600, - "categories": _COCO_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#fcos-resnet-50-fpn", - "map": 39.2, - }, - ) - DEFAULT = COCO_V1 - - -@handle_legacy_interface( - weights=("pretrained", FCOS_ResNet50_FPN_Weights.COCO_V1), - weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), -) -def fcos_resnet50_fpn( - *, - weights: Optional[FCOS_ResNet50_FPN_Weights] = None, - progress: bool = True, - num_classes: Optional[int] = None, - weights_backbone: Optional[ResNet50_Weights] = None, - trainable_backbone_layers: Optional[int] = None, - **kwargs: Any, -) -> FCOS: - weights = FCOS_ResNet50_FPN_Weights.verify(weights) - weights_backbone = ResNet50_Weights.verify(weights_backbone) - - if weights is not None: - weights_backbone = None - num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) - elif num_classes is None: - num_classes = 91 - - is_trained = weights is not None or weights_backbone is not None - 
trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3) - norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d - - backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer) - backbone = _resnet_fpn_extractor( - backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256) - ) - model = FCOS(backbone, num_classes, **kwargs) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model diff --git a/torchvision/prototype/models/detection/keypoint_rcnn.py b/torchvision/prototype/models/detection/keypoint_rcnn.py deleted file mode 100644 index e0b4d7061fa..00000000000 --- a/torchvision/prototype/models/detection/keypoint_rcnn.py +++ /dev/null @@ -1,108 +0,0 @@ -from typing import Any, Optional - -from torch import nn -from torchvision.prototype.transforms import ObjectDetectionEval -from torchvision.transforms.functional import InterpolationMode - -from ....models.detection.keypoint_rcnn import ( - _resnet_fpn_extractor, - _validate_trainable_layers, - KeypointRCNN, - misc_nn_ops, - overwrite_eps, -) -from .._api import WeightsEnum, Weights -from .._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES -from .._utils import handle_legacy_interface, _ovewrite_value_param -from ..resnet import ResNet50_Weights, resnet50 - - -__all__ = [ - "KeypointRCNN", - "KeypointRCNN_ResNet50_FPN_Weights", - "keypointrcnn_resnet50_fpn", -] - - -_COMMON_META = { - "task": "image_object_detection", - "architecture": "KeypointRCNN", - "publication_year": 2017, - "categories": _COCO_PERSON_CATEGORIES, - "keypoint_names": _COCO_PERSON_KEYPOINT_NAMES, - "interpolation": InterpolationMode.BILINEAR, -} - - -class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum): - COCO_LEGACY = Weights( - url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth", - transforms=ObjectDetectionEval, - meta={ - **_COMMON_META, - "num_params": 59137258, - "recipe": "https://github.com/pytorch/vision/issues/1606", - "map": 50.6, - "map_kp": 61.1, - }, - ) - COCO_V1 = Weights( - url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth", - transforms=ObjectDetectionEval, - meta={ - **_COMMON_META, - "num_params": 59137258, - "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#keypoint-r-cnn", - "map": 54.6, - "map_kp": 65.0, - }, - ) - DEFAULT = COCO_V1 - - -@handle_legacy_interface( - weights=( - "pretrained", - lambda kwargs: KeypointRCNN_ResNet50_FPN_Weights.COCO_LEGACY - if kwargs["pretrained"] == "legacy" - else KeypointRCNN_ResNet50_FPN_Weights.COCO_V1, - ), - weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), -) -def keypointrcnn_resnet50_fpn( - *, - weights: Optional[KeypointRCNN_ResNet50_FPN_Weights] = None, - progress: bool = True, - num_classes: Optional[int] = None, - num_keypoints: Optional[int] = None, - weights_backbone: Optional[ResNet50_Weights] = None, - trainable_backbone_layers: Optional[int] = None, - **kwargs: Any, -) -> KeypointRCNN: - weights = KeypointRCNN_ResNet50_FPN_Weights.verify(weights) - weights_backbone = ResNet50_Weights.verify(weights_backbone) - - if weights is not None: - weights_backbone = None - num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) - num_keypoints = _ovewrite_value_param(num_keypoints, len(weights.meta["keypoint_names"])) - else: 
- if num_classes is None: - num_classes = 2 - if num_keypoints is None: - num_keypoints = 17 - - is_trained = weights is not None or weights_backbone is not None - trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3) - norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d - - backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer) - backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers) - model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - if weights == KeypointRCNN_ResNet50_FPN_Weights.COCO_V1: - overwrite_eps(model, 0.0) - - return model diff --git a/torchvision/prototype/models/detection/mask_rcnn.py b/torchvision/prototype/models/detection/mask_rcnn.py deleted file mode 100644 index 187bf6912b4..00000000000 --- a/torchvision/prototype/models/detection/mask_rcnn.py +++ /dev/null @@ -1,81 +0,0 @@ -from typing import Any, Optional - -from torch import nn -from torchvision.prototype.transforms import ObjectDetectionEval -from torchvision.transforms.functional import InterpolationMode - -from ....models.detection.mask_rcnn import ( - _resnet_fpn_extractor, - _validate_trainable_layers, - MaskRCNN, - misc_nn_ops, - overwrite_eps, -) -from .._api import WeightsEnum, Weights -from .._meta import _COCO_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_value_param -from ..resnet import ResNet50_Weights, resnet50 - - -__all__ = [ - "MaskRCNN", - "MaskRCNN_ResNet50_FPN_Weights", - "maskrcnn_resnet50_fpn", -] - - -class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum): - COCO_V1 = Weights( - url="https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth", - transforms=ObjectDetectionEval, - meta={ - "task": "image_object_detection", - "architecture": "MaskRCNN", - "publication_year": 2017, - "num_params": 44401393, - "categories": _COCO_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#mask-r-cnn", - "map": 37.9, - "map_mask": 34.6, - }, - ) - DEFAULT = COCO_V1 - - -@handle_legacy_interface( - weights=("pretrained", MaskRCNN_ResNet50_FPN_Weights.COCO_V1), - weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), -) -def maskrcnn_resnet50_fpn( - *, - weights: Optional[MaskRCNN_ResNet50_FPN_Weights] = None, - progress: bool = True, - num_classes: Optional[int] = None, - weights_backbone: Optional[ResNet50_Weights] = None, - trainable_backbone_layers: Optional[int] = None, - **kwargs: Any, -) -> MaskRCNN: - weights = MaskRCNN_ResNet50_FPN_Weights.verify(weights) - weights_backbone = ResNet50_Weights.verify(weights_backbone) - - if weights is not None: - weights_backbone = None - num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) - elif num_classes is None: - num_classes = 91 - - is_trained = weights is not None or weights_backbone is not None - trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3) - norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d - - backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer) - backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers) - model = MaskRCNN(backbone, num_classes=num_classes, **kwargs) - - if weights is not None: - 
model.load_state_dict(weights.get_state_dict(progress=progress)) - if weights == MaskRCNN_ResNet50_FPN_Weights.COCO_V1: - overwrite_eps(model, 0.0) - - return model diff --git a/torchvision/prototype/models/detection/retinanet.py b/torchvision/prototype/models/detection/retinanet.py deleted file mode 100644 index eadd6c635ca..00000000000 --- a/torchvision/prototype/models/detection/retinanet.py +++ /dev/null @@ -1,84 +0,0 @@ -from typing import Any, Optional - -from torch import nn -from torchvision.prototype.transforms import ObjectDetectionEval -from torchvision.transforms.functional import InterpolationMode - -from ....models.detection.retinanet import ( - _resnet_fpn_extractor, - _validate_trainable_layers, - RetinaNet, - LastLevelP6P7, - misc_nn_ops, - overwrite_eps, -) -from .._api import WeightsEnum, Weights -from .._meta import _COCO_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_value_param -from ..resnet import ResNet50_Weights, resnet50 - - -__all__ = [ - "RetinaNet", - "RetinaNet_ResNet50_FPN_Weights", - "retinanet_resnet50_fpn", -] - - -class RetinaNet_ResNet50_FPN_Weights(WeightsEnum): - COCO_V1 = Weights( - url="https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth", - transforms=ObjectDetectionEval, - meta={ - "task": "image_object_detection", - "architecture": "RetinaNet", - "publication_year": 2017, - "num_params": 34014999, - "categories": _COCO_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#retinanet", - "map": 36.4, - }, - ) - DEFAULT = COCO_V1 - - -@handle_legacy_interface( - weights=("pretrained", RetinaNet_ResNet50_FPN_Weights.COCO_V1), - weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), -) -def retinanet_resnet50_fpn( - *, - weights: Optional[RetinaNet_ResNet50_FPN_Weights] = None, - progress: bool = True, - num_classes: Optional[int] = None, - weights_backbone: Optional[ResNet50_Weights] = None, - trainable_backbone_layers: Optional[int] = None, - **kwargs: Any, -) -> RetinaNet: - weights = RetinaNet_ResNet50_FPN_Weights.verify(weights) - weights_backbone = ResNet50_Weights.verify(weights_backbone) - - if weights is not None: - weights_backbone = None - num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) - elif num_classes is None: - num_classes = 91 - - is_trained = weights is not None or weights_backbone is not None - trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3) - norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d - - backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer) - # skip P2 because it generates too many anchors (according to their paper) - backbone = _resnet_fpn_extractor( - backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256) - ) - model = RetinaNet(backbone, num_classes, **kwargs) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - if weights == RetinaNet_ResNet50_FPN_Weights.COCO_V1: - overwrite_eps(model, 0.0) - - return model diff --git a/torchvision/prototype/models/detection/ssd.py b/torchvision/prototype/models/detection/ssd.py deleted file mode 100644 index 3cab044958d..00000000000 --- a/torchvision/prototype/models/detection/ssd.py +++ /dev/null @@ -1,93 +0,0 @@ -import warnings -from typing import Any, Optional - -from 
torchvision.prototype.transforms import ObjectDetectionEval -from torchvision.transforms.functional import InterpolationMode - -from ....models.detection.ssd import ( - _validate_trainable_layers, - _vgg_extractor, - DefaultBoxGenerator, - SSD, -) -from .._api import WeightsEnum, Weights -from .._meta import _COCO_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_value_param -from ..vgg import VGG16_Weights, vgg16 - - -__all__ = [ - "SSD300_VGG16_Weights", - "ssd300_vgg16", -] - - -class SSD300_VGG16_Weights(WeightsEnum): - COCO_V1 = Weights( - url="https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth", - transforms=ObjectDetectionEval, - meta={ - "task": "image_object_detection", - "architecture": "SSD", - "publication_year": 2015, - "num_params": 35641826, - "size": (300, 300), - "categories": _COCO_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#ssd300-vgg16", - "map": 25.1, - }, - ) - DEFAULT = COCO_V1 - - -@handle_legacy_interface( - weights=("pretrained", SSD300_VGG16_Weights.COCO_V1), - weights_backbone=("pretrained_backbone", VGG16_Weights.IMAGENET1K_FEATURES), -) -def ssd300_vgg16( - *, - weights: Optional[SSD300_VGG16_Weights] = None, - progress: bool = True, - num_classes: Optional[int] = None, - weights_backbone: Optional[VGG16_Weights] = None, - trainable_backbone_layers: Optional[int] = None, - **kwargs: Any, -) -> SSD: - weights = SSD300_VGG16_Weights.verify(weights) - weights_backbone = VGG16_Weights.verify(weights_backbone) - - if "size" in kwargs: - warnings.warn("The size of the model is already fixed; ignoring the parameter.") - - if weights is not None: - weights_backbone = None - num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) - elif num_classes is None: - num_classes = 91 - - trainable_backbone_layers = _validate_trainable_layers( - weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 4 - ) - - # Use custom backbones more appropriate for SSD - backbone = vgg16(weights=weights_backbone, progress=progress) - backbone = _vgg_extractor(backbone, False, trainable_backbone_layers) - anchor_generator = DefaultBoxGenerator( - [[2], [2, 3], [2, 3], [2, 3], [2], [2]], - scales=[0.07, 0.15, 0.33, 0.51, 0.69, 0.87, 1.05], - steps=[8, 16, 32, 64, 100, 300], - ) - - defaults = { - # Rescale the input in a way compatible to the backbone - "image_mean": [0.48235, 0.45882, 0.40784], - "image_std": [1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0], # undo the 0-1 scaling of toTensor - } - kwargs: Any = {**defaults, **kwargs} - model = SSD(backbone, anchor_generator, (300, 300), num_classes, **kwargs) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model diff --git a/torchvision/prototype/models/detection/ssdlite.py b/torchvision/prototype/models/detection/ssdlite.py deleted file mode 100644 index f69e860dff4..00000000000 --- a/torchvision/prototype/models/detection/ssdlite.py +++ /dev/null @@ -1,129 +0,0 @@ -import warnings -from functools import partial -from typing import Any, Callable, Optional - -from torch import nn -from torchvision.prototype.transforms import ObjectDetectionEval -from torchvision.transforms.functional import InterpolationMode - -from ....models.detection.ssdlite import ( - _mobilenet_extractor, - _normal_init, - _validate_trainable_layers, - DefaultBoxGenerator, - det_utils, - SSD, - SSDLiteHead, -) -from .._api 
import WeightsEnum, Weights -from .._meta import _COCO_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_value_param -from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large - - -__all__ = [ - "SSDLite320_MobileNet_V3_Large_Weights", - "ssdlite320_mobilenet_v3_large", -] - - -class SSDLite320_MobileNet_V3_Large_Weights(WeightsEnum): - COCO_V1 = Weights( - url="https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth", - transforms=ObjectDetectionEval, - meta={ - "task": "image_object_detection", - "architecture": "SSDLite", - "publication_year": 2018, - "num_params": 3440060, - "size": (320, 320), - "categories": _COCO_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#ssdlite320-mobilenetv3-large", - "map": 21.3, - }, - ) - DEFAULT = COCO_V1 - - -@handle_legacy_interface( - weights=("pretrained", SSDLite320_MobileNet_V3_Large_Weights.COCO_V1), - weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1), -) -def ssdlite320_mobilenet_v3_large( - *, - weights: Optional[SSDLite320_MobileNet_V3_Large_Weights] = None, - progress: bool = True, - num_classes: Optional[int] = None, - weights_backbone: Optional[MobileNet_V3_Large_Weights] = None, - trainable_backbone_layers: Optional[int] = None, - norm_layer: Optional[Callable[..., nn.Module]] = None, - **kwargs: Any, -) -> SSD: - weights = SSDLite320_MobileNet_V3_Large_Weights.verify(weights) - weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone) - - if "size" in kwargs: - warnings.warn("The size of the model is already fixed; ignoring the parameter.") - - if weights is not None: - weights_backbone = None - num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) - elif num_classes is None: - num_classes = 91 - - trainable_backbone_layers = _validate_trainable_layers( - weights is not None or weights_backbone is not None, trainable_backbone_layers, 6, 6 - ) - - # Enable reduced tail if no pretrained backbone is selected. See Table 6 of MobileNetV3 paper. 
- reduce_tail = weights_backbone is None - - if norm_layer is None: - norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03) - - backbone = mobilenet_v3_large( - weights=weights_backbone, progress=progress, norm_layer=norm_layer, reduced_tail=reduce_tail, **kwargs - ) - if weights_backbone is None: - # Change the default initialization scheme if not pretrained - _normal_init(backbone) - backbone = _mobilenet_extractor( - backbone, - trainable_backbone_layers, - norm_layer, - ) - - size = (320, 320) - anchor_generator = DefaultBoxGenerator([[2, 3] for _ in range(6)], min_ratio=0.2, max_ratio=0.95) - out_channels = det_utils.retrieve_out_channels(backbone, size) - num_anchors = anchor_generator.num_anchors_per_location() - if len(out_channels) != len(anchor_generator.aspect_ratios): - raise ValueError( - f"The length of the output channels from the backbone {len(out_channels)} do not match the length of the anchor generator aspect ratios {len(anchor_generator.aspect_ratios)}" - ) - - defaults = { - "score_thresh": 0.001, - "nms_thresh": 0.55, - "detections_per_img": 300, - "topk_candidates": 300, - # Rescale the input in a way compatible to the backbone: - # The following mean/std rescale the data from [0, 1] to [-1, -1] - "image_mean": [0.5, 0.5, 0.5], - "image_std": [0.5, 0.5, 0.5], - } - kwargs: Any = {**defaults, **kwargs} - model = SSD( - backbone, - anchor_generator, - size, - num_classes, - head=SSDLiteHead(out_channels, num_anchors, num_classes, norm_layer), - **kwargs, - ) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model diff --git a/torchvision/prototype/models/efficientnet.py b/torchvision/prototype/models/efficientnet.py deleted file mode 100644 index cb6d2bb2b35..00000000000 --- a/torchvision/prototype/models/efficientnet.py +++ /dev/null @@ -1,453 +0,0 @@ -from functools import partial -from typing import Any, Optional, Sequence, Union - -from torch import nn -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ...models.efficientnet import EfficientNet, MBConvConfig, FusedMBConvConfig, _efficientnet_conf -from ._api import WeightsEnum, Weights -from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param - - -__all__ = [ - "EfficientNet", - "EfficientNet_B0_Weights", - "EfficientNet_B1_Weights", - "EfficientNet_B2_Weights", - "EfficientNet_B3_Weights", - "EfficientNet_B4_Weights", - "EfficientNet_B5_Weights", - "EfficientNet_B6_Weights", - "EfficientNet_B7_Weights", - "EfficientNet_V2_S_Weights", - "EfficientNet_V2_M_Weights", - "EfficientNet_V2_L_Weights", - "efficientnet_b0", - "efficientnet_b1", - "efficientnet_b2", - "efficientnet_b3", - "efficientnet_b4", - "efficientnet_b5", - "efficientnet_b6", - "efficientnet_b7", - "efficientnet_v2_s", - "efficientnet_v2_m", - "efficientnet_v2_l", -] - - -def _efficientnet( - inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]], - dropout: float, - last_channel: Optional[int], - weights: Optional[WeightsEnum], - progress: bool, - **kwargs: Any, -) -> EfficientNet: - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - - model = EfficientNet(inverted_residual_setting, dropout, last_channel=last_channel, **kwargs) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model - - -_COMMON_META = { - "task": 
"image_classification", - "categories": _IMAGENET_CATEGORIES, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#efficientnet", -} - - -_COMMON_META_V1 = { - **_COMMON_META, - "architecture": "EfficientNet", - "publication_year": 2019, - "interpolation": InterpolationMode.BICUBIC, - "min_size": (1, 1), -} - - -_COMMON_META_V2 = { - **_COMMON_META, - "architecture": "EfficientNetV2", - "publication_year": 2021, - "interpolation": InterpolationMode.BILINEAR, - "min_size": (33, 33), -} - - -class EfficientNet_B0_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth", - transforms=partial( - ImageClassificationEval, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC - ), - meta={ - **_COMMON_META_V1, - "num_params": 5288548, - "size": (224, 224), - "acc@1": 77.692, - "acc@5": 93.532, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class EfficientNet_B1_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth", - transforms=partial( - ImageClassificationEval, crop_size=240, resize_size=256, interpolation=InterpolationMode.BICUBIC - ), - meta={ - **_COMMON_META_V1, - "num_params": 7794184, - "size": (240, 240), - "acc@1": 78.642, - "acc@5": 94.186, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/efficientnet_b1-c27df63c.pth", - transforms=partial( - ImageClassificationEval, crop_size=240, resize_size=255, interpolation=InterpolationMode.BILINEAR - ), - meta={ - **_COMMON_META_V1, - "num_params": 7794184, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-lr-wd-crop-tuning", - "interpolation": InterpolationMode.BILINEAR, - "size": (240, 240), - "acc@1": 79.838, - "acc@5": 94.934, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class EfficientNet_B2_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth", - transforms=partial( - ImageClassificationEval, crop_size=288, resize_size=288, interpolation=InterpolationMode.BICUBIC - ), - meta={ - **_COMMON_META_V1, - "num_params": 9109994, - "size": (288, 288), - "acc@1": 80.608, - "acc@5": 95.310, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class EfficientNet_B3_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth", - transforms=partial( - ImageClassificationEval, crop_size=300, resize_size=320, interpolation=InterpolationMode.BICUBIC - ), - meta={ - **_COMMON_META_V1, - "num_params": 12233232, - "size": (300, 300), - "acc@1": 82.008, - "acc@5": 96.054, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class EfficientNet_B4_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth", - transforms=partial( - ImageClassificationEval, crop_size=380, resize_size=384, interpolation=InterpolationMode.BICUBIC - ), - meta={ - **_COMMON_META_V1, - "num_params": 19341616, - "size": (380, 380), - "acc@1": 83.384, - "acc@5": 96.594, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class EfficientNet_B5_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth", - transforms=partial( - ImageClassificationEval, crop_size=456, resize_size=456, interpolation=InterpolationMode.BICUBIC - ), - meta={ - **_COMMON_META_V1, - 
"num_params": 30389784, - "size": (456, 456), - "acc@1": 83.444, - "acc@5": 96.628, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class EfficientNet_B6_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth", - transforms=partial( - ImageClassificationEval, crop_size=528, resize_size=528, interpolation=InterpolationMode.BICUBIC - ), - meta={ - **_COMMON_META_V1, - "num_params": 43040704, - "size": (528, 528), - "acc@1": 84.008, - "acc@5": 96.916, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class EfficientNet_B7_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth", - transforms=partial( - ImageClassificationEval, crop_size=600, resize_size=600, interpolation=InterpolationMode.BICUBIC - ), - meta={ - **_COMMON_META_V1, - "num_params": 66347960, - "size": (600, 600), - "acc@1": 84.122, - "acc@5": 96.908, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class EfficientNet_V2_S_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/efficientnet_v2_s-dd5fe13b.pth", - transforms=partial( - ImageClassificationEval, - crop_size=384, - resize_size=384, - interpolation=InterpolationMode.BILINEAR, - ), - meta={ - **_COMMON_META_V2, - "num_params": 21458488, - "size": (384, 384), - "acc@1": 84.228, - "acc@5": 96.878, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class EfficientNet_V2_M_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/efficientnet_v2_m-dc08266a.pth", - transforms=partial( - ImageClassificationEval, - crop_size=480, - resize_size=480, - interpolation=InterpolationMode.BILINEAR, - ), - meta={ - **_COMMON_META_V2, - "num_params": 54139356, - "size": (480, 480), - "acc@1": 85.112, - "acc@5": 97.156, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class EfficientNet_V2_L_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/efficientnet_v2_l-59c71312.pth", - transforms=partial( - ImageClassificationEval, - crop_size=480, - resize_size=480, - interpolation=InterpolationMode.BICUBIC, - mean=(0.5, 0.5, 0.5), - std=(0.5, 0.5, 0.5), - ), - meta={ - **_COMMON_META_V2, - "num_params": 118515272, - "size": (480, 480), - "acc@1": 85.808, - "acc@5": 97.788, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -@handle_legacy_interface(weights=("pretrained", EfficientNet_B0_Weights.IMAGENET1K_V1)) -def efficientnet_b0( - *, weights: Optional[EfficientNet_B0_Weights] = None, progress: bool = True, **kwargs: Any -) -> EfficientNet: - weights = EfficientNet_B0_Weights.verify(weights) - - inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b0", width_mult=1.0, depth_mult=1.0) - return _efficientnet(inverted_residual_setting, 0.2, last_channel, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", EfficientNet_B1_Weights.IMAGENET1K_V1)) -def efficientnet_b1( - *, weights: Optional[EfficientNet_B1_Weights] = None, progress: bool = True, **kwargs: Any -) -> EfficientNet: - weights = EfficientNet_B1_Weights.verify(weights) - - inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b1", width_mult=1.0, depth_mult=1.1) - return _efficientnet(inverted_residual_setting, 0.2, last_channel, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", EfficientNet_B2_Weights.IMAGENET1K_V1)) -def efficientnet_b2( - *, weights: Optional[EfficientNet_B2_Weights] = None, progress: bool = True, **kwargs: Any 
-) -> EfficientNet: - weights = EfficientNet_B2_Weights.verify(weights) - - inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b2", width_mult=1.1, depth_mult=1.2) - return _efficientnet(inverted_residual_setting, 0.3, last_channel, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", EfficientNet_B3_Weights.IMAGENET1K_V1)) -def efficientnet_b3( - *, weights: Optional[EfficientNet_B3_Weights] = None, progress: bool = True, **kwargs: Any -) -> EfficientNet: - weights = EfficientNet_B3_Weights.verify(weights) - - inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b3", width_mult=1.2, depth_mult=1.4) - return _efficientnet(inverted_residual_setting, 0.3, last_channel, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", EfficientNet_B4_Weights.IMAGENET1K_V1)) -def efficientnet_b4( - *, weights: Optional[EfficientNet_B4_Weights] = None, progress: bool = True, **kwargs: Any -) -> EfficientNet: - weights = EfficientNet_B4_Weights.verify(weights) - - inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b4", width_mult=1.4, depth_mult=1.8) - return _efficientnet(inverted_residual_setting, 0.4, last_channel, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", EfficientNet_B5_Weights.IMAGENET1K_V1)) -def efficientnet_b5( - *, weights: Optional[EfficientNet_B5_Weights] = None, progress: bool = True, **kwargs: Any -) -> EfficientNet: - weights = EfficientNet_B5_Weights.verify(weights) - - inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b5", width_mult=1.6, depth_mult=2.2) - return _efficientnet( - inverted_residual_setting, - 0.4, - last_channel, - weights, - progress, - norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01), - **kwargs, - ) - - -@handle_legacy_interface(weights=("pretrained", EfficientNet_B6_Weights.IMAGENET1K_V1)) -def efficientnet_b6( - *, weights: Optional[EfficientNet_B6_Weights] = None, progress: bool = True, **kwargs: Any -) -> EfficientNet: - weights = EfficientNet_B6_Weights.verify(weights) - - inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b6", width_mult=1.8, depth_mult=2.6) - return _efficientnet( - inverted_residual_setting, - 0.5, - last_channel, - weights, - progress, - norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01), - **kwargs, - ) - - -@handle_legacy_interface(weights=("pretrained", EfficientNet_B7_Weights.IMAGENET1K_V1)) -def efficientnet_b7( - *, weights: Optional[EfficientNet_B7_Weights] = None, progress: bool = True, **kwargs: Any -) -> EfficientNet: - weights = EfficientNet_B7_Weights.verify(weights) - - inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b7", width_mult=2.0, depth_mult=3.1) - return _efficientnet( - inverted_residual_setting, - 0.5, - last_channel, - weights, - progress, - norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01), - **kwargs, - ) - - -@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_S_Weights.IMAGENET1K_V1)) -def efficientnet_v2_s( - *, weights: Optional[EfficientNet_V2_S_Weights] = None, progress: bool = True, **kwargs: Any -) -> EfficientNet: - weights = EfficientNet_V2_S_Weights.verify(weights) - - inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_s") - return _efficientnet( - inverted_residual_setting, - 0.2, - last_channel, - weights, - progress, - norm_layer=partial(nn.BatchNorm2d, eps=1e-03), - **kwargs, - ) - - 
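All of the builders in these deleted prototype files follow the same pattern: verify() normalizes the weights argument, _ovewrite_named_param pins num_classes to the length of the weights' "categories" metadata before the model is constructed, and handle_legacy_interface keeps the old pretrained=True call path working by substituting the default weights named in each decorator. A rough usage sketch, under the assumption that the builders are re-exported from torchvision.prototype.models (the package __init__ is not shown in this patch):

from torchvision.prototype.models import efficientnet_v2_s, EfficientNet_V2_S_Weights

# New-style call: the weight entry is chosen explicitly (or via .DEFAULT).
model = efficientnet_v2_s(weights=EfficientNet_V2_S_Weights.IMAGENET1K_V1)

# num_classes is derived from the weights metadata; a conflicting value would
# make _ovewrite_named_param raise a ValueError instead of silently diverging.
# model = efficientnet_v2_s(weights=EfficientNet_V2_S_Weights.IMAGENET1K_V1, num_classes=10)

# Legacy call: handle_legacy_interface rewrites pretrained=True into
# weights=EfficientNet_V2_S_Weights.IMAGENET1K_V1 (per the decorator above) and warns.
model_legacy = efficientnet_v2_s(pretrained=True)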
-@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_M_Weights.IMAGENET1K_V1)) -def efficientnet_v2_m( - *, weights: Optional[EfficientNet_V2_M_Weights] = None, progress: bool = True, **kwargs: Any -) -> EfficientNet: - weights = EfficientNet_V2_M_Weights.verify(weights) - - inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_m") - return _efficientnet( - inverted_residual_setting, - 0.3, - last_channel, - weights, - progress, - norm_layer=partial(nn.BatchNorm2d, eps=1e-03), - **kwargs, - ) - - -@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_L_Weights.IMAGENET1K_V1)) -def efficientnet_v2_l( - *, weights: Optional[EfficientNet_V2_L_Weights] = None, progress: bool = True, **kwargs: Any -) -> EfficientNet: - weights = EfficientNet_V2_L_Weights.verify(weights) - - inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_l") - return _efficientnet( - inverted_residual_setting, - 0.4, - last_channel, - weights, - progress, - norm_layer=partial(nn.BatchNorm2d, eps=1e-03), - **kwargs, - ) diff --git a/torchvision/prototype/models/googlenet.py b/torchvision/prototype/models/googlenet.py deleted file mode 100644 index 70dc0d9db5c..00000000000 --- a/torchvision/prototype/models/googlenet.py +++ /dev/null @@ -1,63 +0,0 @@ -import warnings -from functools import partial -from typing import Any, Optional - -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ...models.googlenet import GoogLeNet, GoogLeNetOutputs, _GoogLeNetOutputs -from ._api import WeightsEnum, Weights -from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param - - -__all__ = ["GoogLeNet", "GoogLeNetOutputs", "_GoogLeNetOutputs", "GoogLeNet_Weights", "googlenet"] - - -class GoogLeNet_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/googlenet-1378be20.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - "task": "image_classification", - "architecture": "GoogLeNet", - "publication_year": 2014, - "num_params": 6624904, - "size": (224, 224), - "min_size": (15, 15), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#googlenet", - "acc@1": 69.778, - "acc@5": 89.530, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -@handle_legacy_interface(weights=("pretrained", GoogLeNet_Weights.IMAGENET1K_V1)) -def googlenet(*, weights: Optional[GoogLeNet_Weights] = None, progress: bool = True, **kwargs: Any) -> GoogLeNet: - weights = GoogLeNet_Weights.verify(weights) - - original_aux_logits = kwargs.get("aux_logits", False) - if weights is not None: - if "transform_input" not in kwargs: - _ovewrite_named_param(kwargs, "transform_input", True) - _ovewrite_named_param(kwargs, "aux_logits", True) - _ovewrite_named_param(kwargs, "init_weights", False) - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - - model = GoogLeNet(**kwargs) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - if not original_aux_logits: - model.aux_logits = False - model.aux1 = None # type: ignore[assignment] - model.aux2 = None # type: ignore[assignment] - else: - warnings.warn( - "auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them" - ) - - return model diff --git 
a/torchvision/prototype/models/inception.py b/torchvision/prototype/models/inception.py deleted file mode 100644 index eec78a26236..00000000000 --- a/torchvision/prototype/models/inception.py +++ /dev/null @@ -1,57 +0,0 @@ -from functools import partial -from typing import Any, Optional - -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ...models.inception import Inception3, InceptionOutputs, _InceptionOutputs -from ._api import WeightsEnum, Weights -from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param - - -__all__ = ["Inception3", "InceptionOutputs", "_InceptionOutputs", "Inception_V3_Weights", "inception_v3"] - - -class Inception_V3_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth", - transforms=partial(ImageClassificationEval, crop_size=299, resize_size=342), - meta={ - "task": "image_classification", - "architecture": "InceptionV3", - "publication_year": 2015, - "num_params": 27161264, - "size": (299, 299), - "min_size": (75, 75), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#inception-v3", - "acc@1": 77.294, - "acc@5": 93.450, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -@handle_legacy_interface(weights=("pretrained", Inception_V3_Weights.IMAGENET1K_V1)) -def inception_v3(*, weights: Optional[Inception_V3_Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3: - weights = Inception_V3_Weights.verify(weights) - - original_aux_logits = kwargs.get("aux_logits", True) - if weights is not None: - if "transform_input" not in kwargs: - _ovewrite_named_param(kwargs, "transform_input", True) - _ovewrite_named_param(kwargs, "aux_logits", True) - _ovewrite_named_param(kwargs, "init_weights", False) - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - - model = Inception3(**kwargs) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - if not original_aux_logits: - model.aux_logits = False - model.AuxLogits = None - - return model diff --git a/torchvision/prototype/models/mnasnet.py b/torchvision/prototype/models/mnasnet.py deleted file mode 100644 index c48e34a7be5..00000000000 --- a/torchvision/prototype/models/mnasnet.py +++ /dev/null @@ -1,113 +0,0 @@ -from functools import partial -from typing import Any, Optional - -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ...models.mnasnet import MNASNet -from ._api import WeightsEnum, Weights -from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param - - -__all__ = [ - "MNASNet", - "MNASNet0_5_Weights", - "MNASNet0_75_Weights", - "MNASNet1_0_Weights", - "MNASNet1_3_Weights", - "mnasnet0_5", - "mnasnet0_75", - "mnasnet1_0", - "mnasnet1_3", -] - - -_COMMON_META = { - "task": "image_classification", - "architecture": "MNASNet", - "publication_year": 2018, - "size": (224, 224), - "min_size": (1, 1), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "recipe": "https://github.com/1e100/mnasnet_trainer", -} - - -class MNASNet0_5_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - 
url="https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 2218512, - "acc@1": 67.734, - "acc@5": 87.490, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class MNASNet0_75_Weights(WeightsEnum): - # If a default model is added here the corresponding changes need to be done in mnasnet0_75 - pass - - -class MNASNet1_0_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 4383312, - "acc@1": 73.456, - "acc@5": 91.510, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class MNASNet1_3_Weights(WeightsEnum): - # If a default model is added here the corresponding changes need to be done in mnasnet1_3 - pass - - -def _mnasnet(alpha: float, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any) -> MNASNet: - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - - model = MNASNet(alpha, **kwargs) - - if weights: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model - - -@handle_legacy_interface(weights=("pretrained", MNASNet0_5_Weights.IMAGENET1K_V1)) -def mnasnet0_5(*, weights: Optional[MNASNet0_5_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: - weights = MNASNet0_5_Weights.verify(weights) - - return _mnasnet(0.5, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", None)) -def mnasnet0_75(*, weights: Optional[MNASNet0_75_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: - weights = MNASNet0_75_Weights.verify(weights) - - return _mnasnet(0.75, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", MNASNet1_0_Weights.IMAGENET1K_V1)) -def mnasnet1_0(*, weights: Optional[MNASNet1_0_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: - weights = MNASNet1_0_Weights.verify(weights) - - return _mnasnet(1.0, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", None)) -def mnasnet1_3(*, weights: Optional[MNASNet1_3_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: - weights = MNASNet1_3_Weights.verify(weights) - - return _mnasnet(1.3, weights, progress, **kwargs) diff --git a/torchvision/prototype/models/mobilenet.py b/torchvision/prototype/models/mobilenet.py deleted file mode 100644 index 0a270d14d3a..00000000000 --- a/torchvision/prototype/models/mobilenet.py +++ /dev/null @@ -1,6 +0,0 @@ -from .mobilenetv2 import * # noqa: F401, F403 -from .mobilenetv3 import * # noqa: F401, F403 -from .mobilenetv2 import __all__ as mv2_all -from .mobilenetv3 import __all__ as mv3_all - -__all__ = mv2_all + mv3_all diff --git a/torchvision/prototype/models/mobilenetv2.py b/torchvision/prototype/models/mobilenetv2.py deleted file mode 100644 index 71b412898fe..00000000000 --- a/torchvision/prototype/models/mobilenetv2.py +++ /dev/null @@ -1,66 +0,0 @@ -from functools import partial -from typing import Any, Optional - -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ...models.mobilenetv2 import MobileNetV2 -from ._api import WeightsEnum, Weights -from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param - - -__all__ = ["MobileNetV2", 
"MobileNet_V2_Weights", "mobilenet_v2"] - - -_COMMON_META = { - "task": "image_classification", - "architecture": "MobileNetV2", - "publication_year": 2018, - "num_params": 3504872, - "size": (224, 224), - "min_size": (1, 1), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, -} - - -class MobileNet_V2_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/mobilenet_v2-b0353104.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv2", - "acc@1": 71.878, - "acc@5": 90.286, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/mobilenet_v2-7ebf99e0.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-reg-tuning", - "acc@1": 72.154, - "acc@5": 90.822, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -@handle_legacy_interface(weights=("pretrained", MobileNet_V2_Weights.IMAGENET1K_V1)) -def mobilenet_v2( - *, weights: Optional[MobileNet_V2_Weights] = None, progress: bool = True, **kwargs: Any -) -> MobileNetV2: - weights = MobileNet_V2_Weights.verify(weights) - - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - - model = MobileNetV2(**kwargs) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model diff --git a/torchvision/prototype/models/mobilenetv3.py b/torchvision/prototype/models/mobilenetv3.py deleted file mode 100644 index aaf9c2c85a4..00000000000 --- a/torchvision/prototype/models/mobilenetv3.py +++ /dev/null @@ -1,109 +0,0 @@ -from functools import partial -from typing import Any, Optional, List - -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ...models.mobilenetv3 import MobileNetV3, _mobilenet_v3_conf, InvertedResidualConfig -from ._api import WeightsEnum, Weights -from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param - - -__all__ = [ - "MobileNetV3", - "MobileNet_V3_Large_Weights", - "MobileNet_V3_Small_Weights", - "mobilenet_v3_large", - "mobilenet_v3_small", -] - - -def _mobilenet_v3( - inverted_residual_setting: List[InvertedResidualConfig], - last_channel: int, - weights: Optional[WeightsEnum], - progress: bool, - **kwargs: Any, -) -> MobileNetV3: - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - - model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model - - -_COMMON_META = { - "task": "image_classification", - "architecture": "MobileNetV3", - "publication_year": 2019, - "size": (224, 224), - "min_size": (1, 1), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, -} - - -class MobileNet_V3_Large_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 5483032, - "recipe": 
"https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv3-large--small", - "acc@1": 74.042, - "acc@5": 91.340, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/mobilenet_v3_large-5c1a4163.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "num_params": 5483032, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-reg-tuning", - "acc@1": 75.274, - "acc@5": 92.566, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class MobileNet_V3_Small_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 2542856, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv3-large--small", - "acc@1": 67.668, - "acc@5": 87.402, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -@handle_legacy_interface(weights=("pretrained", MobileNet_V3_Large_Weights.IMAGENET1K_V1)) -def mobilenet_v3_large( - *, weights: Optional[MobileNet_V3_Large_Weights] = None, progress: bool = True, **kwargs: Any -) -> MobileNetV3: - weights = MobileNet_V3_Large_Weights.verify(weights) - - inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large", **kwargs) - return _mobilenet_v3(inverted_residual_setting, last_channel, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", MobileNet_V3_Small_Weights.IMAGENET1K_V1)) -def mobilenet_v3_small( - *, weights: Optional[MobileNet_V3_Small_Weights] = None, progress: bool = True, **kwargs: Any -) -> MobileNetV3: - weights = MobileNet_V3_Small_Weights.verify(weights) - - inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_small", **kwargs) - return _mobilenet_v3(inverted_residual_setting, last_channel, weights, progress, **kwargs) diff --git a/torchvision/prototype/models/optical_flow/__init__.py b/torchvision/prototype/models/optical_flow/__init__.py deleted file mode 100644 index 9b78f70b768..00000000000 --- a/torchvision/prototype/models/optical_flow/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .raft import RAFT, raft_large, raft_small, Raft_Large_Weights, Raft_Small_Weights diff --git a/torchvision/prototype/models/optical_flow/raft.py b/torchvision/prototype/models/optical_flow/raft.py deleted file mode 100644 index 24e87f3d4f9..00000000000 --- a/torchvision/prototype/models/optical_flow/raft.py +++ /dev/null @@ -1,251 +0,0 @@ -from typing import Optional - -from torch.nn.modules.batchnorm import BatchNorm2d -from torch.nn.modules.instancenorm import InstanceNorm2d -from torchvision.models.optical_flow import RAFT -from torchvision.models.optical_flow.raft import _raft, BottleneckBlock, ResidualBlock -from torchvision.prototype.transforms import OpticalFlowEval -from torchvision.transforms.functional import InterpolationMode - -from .._api import WeightsEnum -from .._api import Weights -from .._utils import handle_legacy_interface - - -__all__ = ( - "RAFT", - "raft_large", - "raft_small", - "Raft_Large_Weights", - "Raft_Small_Weights", -) - - -_COMMON_META = { - "task": "optical_flow", - "architecture": "RAFT", - "publication_year": 2020, - "interpolation": InterpolationMode.BILINEAR, -} - - -class Raft_Large_Weights(WeightsEnum): - C_T_V1 = Weights( - # Chairs + Things, ported from original paper repo (raft-things.pth) - 
url="https://download.pytorch.org/models/raft_large_C_T_V1-22a6c225.pth", - transforms=OpticalFlowEval, - meta={ - **_COMMON_META, - "num_params": 5257536, - "recipe": "https://github.com/princeton-vl/RAFT", - "sintel_train_cleanpass_epe": 1.4411, - "sintel_train_finalpass_epe": 2.7894, - "kitti_train_per_image_epe": 5.0172, - "kitti_train_f1-all": 17.4506, - }, - ) - - C_T_V2 = Weights( - # Chairs + Things - url="https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth", - transforms=OpticalFlowEval, - meta={ - **_COMMON_META, - "num_params": 5257536, - "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow", - "sintel_train_cleanpass_epe": 1.3822, - "sintel_train_finalpass_epe": 2.7161, - "kitti_train_per_image_epe": 4.5118, - "kitti_train_f1-all": 16.0679, - }, - ) - - C_T_SKHT_V1 = Weights( - # Chairs + Things + Sintel fine-tuning, ported from original paper repo (raft-sintel.pth) - url="https://download.pytorch.org/models/raft_large_C_T_SKHT_V1-0b8c9e55.pth", - transforms=OpticalFlowEval, - meta={ - **_COMMON_META, - "num_params": 5257536, - "recipe": "https://github.com/princeton-vl/RAFT", - "sintel_test_cleanpass_epe": 1.94, - "sintel_test_finalpass_epe": 3.18, - }, - ) - - C_T_SKHT_V2 = Weights( - # Chairs + Things + Sintel fine-tuning, i.e.: - # Chairs + Things + (Sintel + Kitti + HD1K + Things_clean) - # Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel - url="https://download.pytorch.org/models/raft_large_C_T_SKHT_V2-ff5fadd5.pth", - transforms=OpticalFlowEval, - meta={ - **_COMMON_META, - "num_params": 5257536, - "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow", - "sintel_test_cleanpass_epe": 1.819, - "sintel_test_finalpass_epe": 3.067, - }, - ) - - C_T_SKHT_K_V1 = Weights( - # Chairs + Things + Sintel fine-tuning + Kitti fine-tuning, ported from the original repo (sintel-kitti.pth) - url="https://download.pytorch.org/models/raft_large_C_T_SKHT_K_V1-4a6a5039.pth", - transforms=OpticalFlowEval, - meta={ - **_COMMON_META, - "num_params": 5257536, - "recipe": "https://github.com/princeton-vl/RAFT", - "kitti_test_f1-all": 5.10, - }, - ) - - C_T_SKHT_K_V2 = Weights( - # Chairs + Things + Sintel fine-tuning + Kitti fine-tuning i.e.: - # Chairs + Things + (Sintel + Kitti + HD1K + Things_clean) + Kitti - # Same as CT_SKHT with extra fine-tuning on Kitti - # Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel and then on Kitti - url="https://download.pytorch.org/models/raft_large_C_T_SKHT_K_V2-b5c70766.pth", - transforms=OpticalFlowEval, - meta={ - **_COMMON_META, - "num_params": 5257536, - "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow", - "kitti_test_f1-all": 5.19, - }, - ) - - DEFAULT = C_T_SKHT_V2 - - -class Raft_Small_Weights(WeightsEnum): - C_T_V1 = Weights( - # Chairs + Things, ported from original paper repo (raft-small.pth) - url="https://download.pytorch.org/models/raft_small_C_T_V1-ad48884c.pth", - transforms=OpticalFlowEval, - meta={ - **_COMMON_META, - "num_params": 990162, - "recipe": "https://github.com/princeton-vl/RAFT", - "sintel_train_cleanpass_epe": 2.1231, - "sintel_train_finalpass_epe": 3.2790, - "kitti_train_per_image_epe": 7.6557, - "kitti_train_f1-all": 25.2801, - }, - ) - C_T_V2 = Weights( - # Chairs + Things - url="https://download.pytorch.org/models/raft_small_C_T_V2-01064c6d.pth", - 
transforms=OpticalFlowEval, - meta={ - **_COMMON_META, - "num_params": 990162, - "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow", - "sintel_train_cleanpass_epe": 1.9901, - "sintel_train_finalpass_epe": 3.2831, - "kitti_train_per_image_epe": 7.5978, - "kitti_train_f1-all": 25.2369, - }, - ) - - DEFAULT = C_T_V2 - - -@handle_legacy_interface(weights=("pretrained", Raft_Large_Weights.C_T_SKHT_V2)) -def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, **kwargs): - """RAFT model from - `RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_. - - Args: - weights (Raft_Large_Weights, optional): pretrained weights to use. - progress (bool): If True, displays a progress bar of the download to stderr - kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class - to override any default. - - Returns: - nn.Module: The model. - """ - - weights = Raft_Large_Weights.verify(weights) - - model = _raft( - # Feature encoder - feature_encoder_layers=(64, 64, 96, 128, 256), - feature_encoder_block=ResidualBlock, - feature_encoder_norm_layer=InstanceNorm2d, - # Context encoder - context_encoder_layers=(64, 64, 96, 128, 256), - context_encoder_block=ResidualBlock, - context_encoder_norm_layer=BatchNorm2d, - # Correlation block - corr_block_num_levels=4, - corr_block_radius=4, - # Motion encoder - motion_encoder_corr_layers=(256, 192), - motion_encoder_flow_layers=(128, 64), - motion_encoder_out_channels=128, - # Recurrent block - recurrent_block_hidden_state_size=128, - recurrent_block_kernel_size=((1, 5), (5, 1)), - recurrent_block_padding=((0, 2), (2, 0)), - # Flow head - flow_head_hidden_size=256, - # Mask predictor - use_mask_predictor=True, - **kwargs, - ) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model - - -@handle_legacy_interface(weights=("pretrained", Raft_Small_Weights.C_T_V2)) -def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, **kwargs): - """RAFT "small" model from - `RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_. - - Args: - weights (Raft_Small_Weights, optional): pretrained weights to use. - progress (bool): If True, displays a progress bar of the download to stderr - kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class - to override any default. - - Returns: - nn.Module: The model.
- - """ - - weights = Raft_Small_Weights.verify(weights) - - model = _raft( - # Feature encoder - feature_encoder_layers=(32, 32, 64, 96, 128), - feature_encoder_block=BottleneckBlock, - feature_encoder_norm_layer=InstanceNorm2d, - # Context encoder - context_encoder_layers=(32, 32, 64, 96, 160), - context_encoder_block=BottleneckBlock, - context_encoder_norm_layer=None, - # Correlation block - corr_block_num_levels=4, - corr_block_radius=3, - # Motion encoder - motion_encoder_corr_layers=(96,), - motion_encoder_flow_layers=(64, 32), - motion_encoder_out_channels=82, - # Recurrent block - recurrent_block_hidden_state_size=96, - recurrent_block_kernel_size=(3,), - recurrent_block_padding=(1,), - # Flow head - flow_head_hidden_size=128, - # Mask predictor - use_mask_predictor=False, - **kwargs, - ) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - return model diff --git a/torchvision/prototype/models/quantization/__init__.py b/torchvision/prototype/models/quantization/__init__.py deleted file mode 100644 index da8bbba3567..00000000000 --- a/torchvision/prototype/models/quantization/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .googlenet import * -from .inception import * -from .mobilenet import * -from .resnet import * -from .shufflenetv2 import * diff --git a/torchvision/prototype/models/quantization/googlenet.py b/torchvision/prototype/models/quantization/googlenet.py deleted file mode 100644 index cca6ba25060..00000000000 --- a/torchvision/prototype/models/quantization/googlenet.py +++ /dev/null @@ -1,94 +0,0 @@ -import warnings -from functools import partial -from typing import Any, Optional, Union - -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ....models.quantization.googlenet import ( - QuantizableGoogLeNet, - _replace_relu, - quantize_model, -) -from .._api import WeightsEnum, Weights -from .._meta import _IMAGENET_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_named_param -from ..googlenet import GoogLeNet_Weights - - -__all__ = [ - "QuantizableGoogLeNet", - "GoogLeNet_QuantizedWeights", - "googlenet", -] - - -class GoogLeNet_QuantizedWeights(WeightsEnum): - IMAGENET1K_FBGEMM_V1 = Weights( - url="https://download.pytorch.org/models/quantized/googlenet_fbgemm-c00238cf.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - "task": "image_classification", - "architecture": "GoogLeNet", - "publication_year": 2014, - "num_params": 6624904, - "size": (224, 224), - "min_size": (15, 15), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "backend": "fbgemm", - "quantization": "ptq", - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models", - "unquantized": GoogLeNet_Weights.IMAGENET1K_V1, - "acc@1": 69.826, - "acc@5": 89.404, - }, - ) - DEFAULT = IMAGENET1K_FBGEMM_V1 - - -@handle_legacy_interface( - weights=( - "pretrained", - lambda kwargs: GoogLeNet_QuantizedWeights.IMAGENET1K_FBGEMM_V1 - if kwargs.get("quantize", False) - else GoogLeNet_Weights.IMAGENET1K_V1, - ) -) -def googlenet( - *, - weights: Optional[Union[GoogLeNet_QuantizedWeights, GoogLeNet_Weights]] = None, - progress: bool = True, - quantize: bool = False, - **kwargs: Any, -) -> QuantizableGoogLeNet: - weights = (GoogLeNet_QuantizedWeights if quantize else GoogLeNet_Weights).verify(weights) - - original_aux_logits = 
kwargs.get("aux_logits", False) - if weights is not None: - if "transform_input" not in kwargs: - _ovewrite_named_param(kwargs, "transform_input", True) - _ovewrite_named_param(kwargs, "aux_logits", True) - _ovewrite_named_param(kwargs, "init_weights", False) - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - if "backend" in weights.meta: - _ovewrite_named_param(kwargs, "backend", weights.meta["backend"]) - backend = kwargs.pop("backend", "fbgemm") - - model = QuantizableGoogLeNet(**kwargs) - _replace_relu(model) - if quantize: - quantize_model(model, backend) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - if not original_aux_logits: - model.aux_logits = False - model.aux1 = None # type: ignore[assignment] - model.aux2 = None # type: ignore[assignment] - else: - warnings.warn( - "auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them" - ) - - return model diff --git a/torchvision/prototype/models/quantization/inception.py b/torchvision/prototype/models/quantization/inception.py deleted file mode 100644 index 2639b7de14f..00000000000 --- a/torchvision/prototype/models/quantization/inception.py +++ /dev/null @@ -1,90 +0,0 @@ -from functools import partial -from typing import Any, Optional, Union - -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ....models.quantization.inception import ( - QuantizableInception3, - _replace_relu, - quantize_model, -) -from .._api import WeightsEnum, Weights -from .._meta import _IMAGENET_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_named_param -from ..inception import Inception_V3_Weights - - -__all__ = [ - "QuantizableInception3", - "Inception_V3_QuantizedWeights", - "inception_v3", -] - - -class Inception_V3_QuantizedWeights(WeightsEnum): - IMAGENET1K_FBGEMM_V1 = Weights( - url="https://download.pytorch.org/models/quantized/inception_v3_google_fbgemm-71447a44.pth", - transforms=partial(ImageClassificationEval, crop_size=299, resize_size=342), - meta={ - "task": "image_classification", - "architecture": "InceptionV3", - "publication_year": 2015, - "num_params": 27161264, - "size": (299, 299), - "min_size": (75, 75), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "backend": "fbgemm", - "quantization": "ptq", - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models", - "unquantized": Inception_V3_Weights.IMAGENET1K_V1, - "acc@1": 77.176, - "acc@5": 93.354, - }, - ) - DEFAULT = IMAGENET1K_FBGEMM_V1 - - -@handle_legacy_interface( - weights=( - "pretrained", - lambda kwargs: Inception_V3_QuantizedWeights.IMAGENET1K_FBGEMM_V1 - if kwargs.get("quantize", False) - else Inception_V3_Weights.IMAGENET1K_V1, - ) -) -def inception_v3( - *, - weights: Optional[Union[Inception_V3_QuantizedWeights, Inception_V3_Weights]] = None, - progress: bool = True, - quantize: bool = False, - **kwargs: Any, -) -> QuantizableInception3: - weights = (Inception_V3_QuantizedWeights if quantize else Inception_V3_Weights).verify(weights) - - original_aux_logits = kwargs.get("aux_logits", False) - if weights is not None: - if "transform_input" not in kwargs: - _ovewrite_named_param(kwargs, "transform_input", True) - _ovewrite_named_param(kwargs, "aux_logits", True) - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) 
- if "backend" in weights.meta: - _ovewrite_named_param(kwargs, "backend", weights.meta["backend"]) - backend = kwargs.pop("backend", "fbgemm") - - model = QuantizableInception3(**kwargs) - _replace_relu(model) - if quantize: - quantize_model(model, backend) - - if weights is not None: - if quantize and not original_aux_logits: - model.aux_logits = False - model.AuxLogits = None - model.load_state_dict(weights.get_state_dict(progress=progress)) - if not quantize and not original_aux_logits: - model.aux_logits = False - model.AuxLogits = None - - return model diff --git a/torchvision/prototype/models/quantization/mobilenet.py b/torchvision/prototype/models/quantization/mobilenet.py deleted file mode 100644 index 0a270d14d3a..00000000000 --- a/torchvision/prototype/models/quantization/mobilenet.py +++ /dev/null @@ -1,6 +0,0 @@ -from .mobilenetv2 import * # noqa: F401, F403 -from .mobilenetv3 import * # noqa: F401, F403 -from .mobilenetv2 import __all__ as mv2_all -from .mobilenetv3 import __all__ as mv3_all - -__all__ = mv2_all + mv3_all diff --git a/torchvision/prototype/models/quantization/mobilenetv2.py b/torchvision/prototype/models/quantization/mobilenetv2.py deleted file mode 100644 index a9789583fe6..00000000000 --- a/torchvision/prototype/models/quantization/mobilenetv2.py +++ /dev/null @@ -1,81 +0,0 @@ -from functools import partial -from typing import Any, Optional, Union - -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ....models.quantization.mobilenetv2 import ( - QuantizableInvertedResidual, - QuantizableMobileNetV2, - _replace_relu, - quantize_model, -) -from .._api import WeightsEnum, Weights -from .._meta import _IMAGENET_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_named_param -from ..mobilenetv2 import MobileNet_V2_Weights - - -__all__ = [ - "QuantizableMobileNetV2", - "MobileNet_V2_QuantizedWeights", - "mobilenet_v2", -] - - -class MobileNet_V2_QuantizedWeights(WeightsEnum): - IMAGENET1K_QNNPACK_V1 = Weights( - url="https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - "task": "image_classification", - "architecture": "MobileNetV2", - "publication_year": 2018, - "num_params": 3504872, - "size": (224, 224), - "min_size": (1, 1), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "backend": "qnnpack", - "quantization": "qat", - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#qat-mobilenetv2", - "unquantized": MobileNet_V2_Weights.IMAGENET1K_V1, - "acc@1": 71.658, - "acc@5": 90.150, - }, - ) - DEFAULT = IMAGENET1K_QNNPACK_V1 - - -@handle_legacy_interface( - weights=( - "pretrained", - lambda kwargs: MobileNet_V2_QuantizedWeights.IMAGENET1K_QNNPACK_V1 - if kwargs.get("quantize", False) - else MobileNet_V2_Weights.IMAGENET1K_V1, - ) -) -def mobilenet_v2( - *, - weights: Optional[Union[MobileNet_V2_QuantizedWeights, MobileNet_V2_Weights]] = None, - progress: bool = True, - quantize: bool = False, - **kwargs: Any, -) -> QuantizableMobileNetV2: - weights = (MobileNet_V2_QuantizedWeights if quantize else MobileNet_V2_Weights).verify(weights) - - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - if "backend" in weights.meta: - _ovewrite_named_param(kwargs, "backend", weights.meta["backend"]) - backend = kwargs.pop("backend", 
"qnnpack") - - model = QuantizableMobileNetV2(block=QuantizableInvertedResidual, **kwargs) - _replace_relu(model) - if quantize: - quantize_model(model, backend) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model diff --git a/torchvision/prototype/models/quantization/mobilenetv3.py b/torchvision/prototype/models/quantization/mobilenetv3.py deleted file mode 100644 index 915308d948f..00000000000 --- a/torchvision/prototype/models/quantization/mobilenetv3.py +++ /dev/null @@ -1,101 +0,0 @@ -from functools import partial -from typing import Any, List, Optional, Union - -import torch -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ....models.quantization.mobilenetv3 import ( - InvertedResidualConfig, - QuantizableInvertedResidual, - QuantizableMobileNetV3, - _replace_relu, -) -from .._api import WeightsEnum, Weights -from .._meta import _IMAGENET_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_named_param -from ..mobilenetv3 import MobileNet_V3_Large_Weights, _mobilenet_v3_conf - - -__all__ = [ - "QuantizableMobileNetV3", - "MobileNet_V3_Large_QuantizedWeights", - "mobilenet_v3_large", -] - - -def _mobilenet_v3_model( - inverted_residual_setting: List[InvertedResidualConfig], - last_channel: int, - weights: Optional[WeightsEnum], - progress: bool, - quantize: bool, - **kwargs: Any, -) -> QuantizableMobileNetV3: - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - if "backend" in weights.meta: - _ovewrite_named_param(kwargs, "backend", weights.meta["backend"]) - backend = kwargs.pop("backend", "qnnpack") - - model = QuantizableMobileNetV3(inverted_residual_setting, last_channel, block=QuantizableInvertedResidual, **kwargs) - _replace_relu(model) - - if quantize: - model.fuse_model(is_qat=True) - model.qconfig = torch.ao.quantization.get_default_qat_qconfig(backend) - torch.ao.quantization.prepare_qat(model, inplace=True) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - if quantize: - torch.ao.quantization.convert(model, inplace=True) - model.eval() - - return model - - -class MobileNet_V3_Large_QuantizedWeights(WeightsEnum): - IMAGENET1K_QNNPACK_V1 = Weights( - url="https://download.pytorch.org/models/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - "task": "image_classification", - "architecture": "MobileNetV3", - "publication_year": 2019, - "num_params": 5483032, - "size": (224, 224), - "min_size": (1, 1), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "backend": "qnnpack", - "quantization": "qat", - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#qat-mobilenetv3", - "unquantized": MobileNet_V3_Large_Weights.IMAGENET1K_V1, - "acc@1": 73.004, - "acc@5": 90.858, - }, - ) - DEFAULT = IMAGENET1K_QNNPACK_V1 - - -@handle_legacy_interface( - weights=( - "pretrained", - lambda kwargs: MobileNet_V3_Large_QuantizedWeights.IMAGENET1K_QNNPACK_V1 - if kwargs.get("quantize", False) - else MobileNet_V3_Large_Weights.IMAGENET1K_V1, - ) -) -def mobilenet_v3_large( - *, - weights: Optional[Union[MobileNet_V3_Large_QuantizedWeights, MobileNet_V3_Large_Weights]] = None, - progress: bool = True, - quantize: bool = False, - **kwargs: Any, -) -> QuantizableMobileNetV3: - 
weights = (MobileNet_V3_Large_QuantizedWeights if quantize else MobileNet_V3_Large_Weights).verify(weights) - - inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large", **kwargs) - return _mobilenet_v3_model(inverted_residual_setting, last_channel, weights, progress, quantize, **kwargs) diff --git a/torchvision/prototype/models/quantization/resnet.py b/torchvision/prototype/models/quantization/resnet.py deleted file mode 100644 index 9e2e29db0bf..00000000000 --- a/torchvision/prototype/models/quantization/resnet.py +++ /dev/null @@ -1,204 +0,0 @@ -from functools import partial -from typing import Any, List, Optional, Type, Union - -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ....models.quantization.resnet import ( - QuantizableBasicBlock, - QuantizableBottleneck, - QuantizableResNet, - _replace_relu, - quantize_model, -) -from .._api import WeightsEnum, Weights -from .._meta import _IMAGENET_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_named_param -from ..resnet import ResNet18_Weights, ResNet50_Weights, ResNeXt101_32X8D_Weights - - -__all__ = [ - "QuantizableResNet", - "ResNet18_QuantizedWeights", - "ResNet50_QuantizedWeights", - "ResNeXt101_32X8D_QuantizedWeights", - "resnet18", - "resnet50", - "resnext101_32x8d", -] - - -def _resnet( - block: Type[Union[QuantizableBasicBlock, QuantizableBottleneck]], - layers: List[int], - weights: Optional[WeightsEnum], - progress: bool, - quantize: bool, - **kwargs: Any, -) -> QuantizableResNet: - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - if "backend" in weights.meta: - _ovewrite_named_param(kwargs, "backend", weights.meta["backend"]) - backend = kwargs.pop("backend", "fbgemm") - - model = QuantizableResNet(block, layers, **kwargs) - _replace_relu(model) - if quantize: - quantize_model(model, backend) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model - - -_COMMON_META = { - "task": "image_classification", - "size": (224, 224), - "min_size": (1, 1), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "backend": "fbgemm", - "quantization": "ptq", - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models", -} - - -class ResNet18_QuantizedWeights(WeightsEnum): - IMAGENET1K_FBGEMM_V1 = Weights( - url="https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "architecture": "ResNet", - "publication_year": 2015, - "num_params": 11689512, - "unquantized": ResNet18_Weights.IMAGENET1K_V1, - "acc@1": 69.494, - "acc@5": 88.882, - }, - ) - DEFAULT = IMAGENET1K_FBGEMM_V1 - - -class ResNet50_QuantizedWeights(WeightsEnum): - IMAGENET1K_FBGEMM_V1 = Weights( - url="https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "architecture": "ResNet", - "publication_year": 2015, - "num_params": 25557032, - "unquantized": ResNet50_Weights.IMAGENET1K_V1, - "acc@1": 75.920, - "acc@5": 92.814, - }, - ) - IMAGENET1K_FBGEMM_V2 = Weights( - url="https://download.pytorch.org/models/quantized/resnet50_fbgemm-23753f79.pth", - transforms=partial(ImageClassificationEval, crop_size=224, 
resize_size=232), - meta={ - **_COMMON_META, - "architecture": "ResNet", - "publication_year": 2015, - "num_params": 25557032, - "unquantized": ResNet50_Weights.IMAGENET1K_V2, - "acc@1": 80.282, - "acc@5": 94.976, - }, - ) - DEFAULT = IMAGENET1K_FBGEMM_V2 - - -class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum): - IMAGENET1K_FBGEMM_V1 = Weights( - url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "architecture": "ResNeXt", - "publication_year": 2016, - "num_params": 88791336, - "unquantized": ResNeXt101_32X8D_Weights.IMAGENET1K_V1, - "acc@1": 78.986, - "acc@5": 94.480, - }, - ) - IMAGENET1K_FBGEMM_V2 = Weights( - url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm-ee16d00c.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "architecture": "ResNeXt", - "publication_year": 2016, - "num_params": 88791336, - "unquantized": ResNeXt101_32X8D_Weights.IMAGENET1K_V2, - "acc@1": 82.574, - "acc@5": 96.132, - }, - ) - DEFAULT = IMAGENET1K_FBGEMM_V2 - - -@handle_legacy_interface( - weights=( - "pretrained", - lambda kwargs: ResNet18_QuantizedWeights.IMAGENET1K_FBGEMM_V1 - if kwargs.get("quantize", False) - else ResNet18_Weights.IMAGENET1K_V1, - ) -) -def resnet18( - *, - weights: Optional[Union[ResNet18_QuantizedWeights, ResNet18_Weights]] = None, - progress: bool = True, - quantize: bool = False, - **kwargs: Any, -) -> QuantizableResNet: - weights = (ResNet18_QuantizedWeights if quantize else ResNet18_Weights).verify(weights) - - return _resnet(QuantizableBasicBlock, [2, 2, 2, 2], weights, progress, quantize, **kwargs) - - -@handle_legacy_interface( - weights=( - "pretrained", - lambda kwargs: ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1 - if kwargs.get("quantize", False) - else ResNet50_Weights.IMAGENET1K_V1, - ) -) -def resnet50( - *, - weights: Optional[Union[ResNet50_QuantizedWeights, ResNet50_Weights]] = None, - progress: bool = True, - quantize: bool = False, - **kwargs: Any, -) -> QuantizableResNet: - weights = (ResNet50_QuantizedWeights if quantize else ResNet50_Weights).verify(weights) - - return _resnet(QuantizableBottleneck, [3, 4, 6, 3], weights, progress, quantize, **kwargs) - - -@handle_legacy_interface( - weights=( - "pretrained", - lambda kwargs: ResNeXt101_32X8D_QuantizedWeights.IMAGENET1K_FBGEMM_V1 - if kwargs.get("quantize", False) - else ResNeXt101_32X8D_Weights.IMAGENET1K_V1, - ) -) -def resnext101_32x8d( - *, - weights: Optional[Union[ResNeXt101_32X8D_QuantizedWeights, ResNeXt101_32X8D_Weights]] = None, - progress: bool = True, - quantize: bool = False, - **kwargs: Any, -) -> QuantizableResNet: - weights = (ResNeXt101_32X8D_QuantizedWeights if quantize else ResNeXt101_32X8D_Weights).verify(weights) - - _ovewrite_named_param(kwargs, "groups", 32) - _ovewrite_named_param(kwargs, "width_per_group", 8) - return _resnet(QuantizableBottleneck, [3, 4, 23, 3], weights, progress, quantize, **kwargs) diff --git a/torchvision/prototype/models/quantization/shufflenetv2.py b/torchvision/prototype/models/quantization/shufflenetv2.py deleted file mode 100644 index e21349ff8e0..00000000000 --- a/torchvision/prototype/models/quantization/shufflenetv2.py +++ /dev/null @@ -1,136 +0,0 @@ -from functools import partial -from typing import Any, List, Optional, Union - -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import 
InterpolationMode - -from ....models.quantization.shufflenetv2 import ( - QuantizableShuffleNetV2, - _replace_relu, - quantize_model, -) -from .._api import WeightsEnum, Weights -from .._meta import _IMAGENET_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_named_param -from ..shufflenetv2 import ShuffleNet_V2_X0_5_Weights, ShuffleNet_V2_X1_0_Weights - - -__all__ = [ - "QuantizableShuffleNetV2", - "ShuffleNet_V2_X0_5_QuantizedWeights", - "ShuffleNet_V2_X1_0_QuantizedWeights", - "shufflenet_v2_x0_5", - "shufflenet_v2_x1_0", -] - - -def _shufflenetv2( - stages_repeats: List[int], - stages_out_channels: List[int], - *, - weights: Optional[WeightsEnum], - progress: bool, - quantize: bool, - **kwargs: Any, -) -> QuantizableShuffleNetV2: - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - if "backend" in weights.meta: - _ovewrite_named_param(kwargs, "backend", weights.meta["backend"]) - backend = kwargs.pop("backend", "fbgemm") - - model = QuantizableShuffleNetV2(stages_repeats, stages_out_channels, **kwargs) - _replace_relu(model) - if quantize: - quantize_model(model, backend) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model - - -_COMMON_META = { - "task": "image_classification", - "architecture": "ShuffleNetV2", - "publication_year": 2018, - "size": (224, 224), - "min_size": (1, 1), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "backend": "fbgemm", - "quantization": "ptq", - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models", -} - - -class ShuffleNet_V2_X0_5_QuantizedWeights(WeightsEnum): - IMAGENET1K_FBGEMM_V1 = Weights( - url="https://download.pytorch.org/models/quantized/shufflenetv2_x0.5_fbgemm-00845098.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 1366792, - "unquantized": ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1, - "acc@1": 57.972, - "acc@5": 79.780, - }, - ) - DEFAULT = IMAGENET1K_FBGEMM_V1 - - -class ShuffleNet_V2_X1_0_QuantizedWeights(WeightsEnum): - IMAGENET1K_FBGEMM_V1 = Weights( - url="https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-db332c57.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 2278604, - "unquantized": ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1, - "acc@1": 68.360, - "acc@5": 87.582, - }, - ) - DEFAULT = IMAGENET1K_FBGEMM_V1 - - -@handle_legacy_interface( - weights=( - "pretrained", - lambda kwargs: ShuffleNet_V2_X0_5_QuantizedWeights.IMAGENET1K_FBGEMM_V1 - if kwargs.get("quantize", False) - else ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1, - ) -) -def shufflenet_v2_x0_5( - *, - weights: Optional[Union[ShuffleNet_V2_X0_5_QuantizedWeights, ShuffleNet_V2_X0_5_Weights]] = None, - progress: bool = True, - quantize: bool = False, - **kwargs: Any, -) -> QuantizableShuffleNetV2: - weights = (ShuffleNet_V2_X0_5_QuantizedWeights if quantize else ShuffleNet_V2_X0_5_Weights).verify(weights) - return _shufflenetv2( - [4, 8, 4], [24, 48, 96, 192, 1024], weights=weights, progress=progress, quantize=quantize, **kwargs - ) - - -@handle_legacy_interface( - weights=( - "pretrained", - lambda kwargs: ShuffleNet_V2_X1_0_QuantizedWeights.IMAGENET1K_FBGEMM_V1 - if kwargs.get("quantize", False) - else ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1, - ) -) -def shufflenet_v2_x1_0( - *, - weights: 
Optional[Union[ShuffleNet_V2_X1_0_QuantizedWeights, ShuffleNet_V2_X1_0_Weights]] = None, - progress: bool = True, - quantize: bool = False, - **kwargs: Any, -) -> QuantizableShuffleNetV2: - weights = (ShuffleNet_V2_X1_0_QuantizedWeights if quantize else ShuffleNet_V2_X1_0_Weights).verify(weights) - return _shufflenetv2( - [4, 8, 4], [24, 116, 232, 464, 1024], weights=weights, progress=progress, quantize=quantize, **kwargs - ) diff --git a/torchvision/prototype/models/regnet.py b/torchvision/prototype/models/regnet.py deleted file mode 100644 index d5e2b535532..00000000000 --- a/torchvision/prototype/models/regnet.py +++ /dev/null @@ -1,575 +0,0 @@ -from functools import partial -from typing import Any, Optional - -from torch import nn -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ...models.regnet import RegNet, BlockParams -from ._api import WeightsEnum, Weights -from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param - - -__all__ = [ - "RegNet", - "RegNet_Y_400MF_Weights", - "RegNet_Y_800MF_Weights", - "RegNet_Y_1_6GF_Weights", - "RegNet_Y_3_2GF_Weights", - "RegNet_Y_8GF_Weights", - "RegNet_Y_16GF_Weights", - "RegNet_Y_32GF_Weights", - "RegNet_Y_128GF_Weights", - "RegNet_X_400MF_Weights", - "RegNet_X_800MF_Weights", - "RegNet_X_1_6GF_Weights", - "RegNet_X_3_2GF_Weights", - "RegNet_X_8GF_Weights", - "RegNet_X_16GF_Weights", - "RegNet_X_32GF_Weights", - "regnet_y_400mf", - "regnet_y_800mf", - "regnet_y_1_6gf", - "regnet_y_3_2gf", - "regnet_y_8gf", - "regnet_y_16gf", - "regnet_y_32gf", - "regnet_y_128gf", - "regnet_x_400mf", - "regnet_x_800mf", - "regnet_x_1_6gf", - "regnet_x_3_2gf", - "regnet_x_8gf", - "regnet_x_16gf", - "regnet_x_32gf", -] - -_COMMON_META = { - "task": "image_classification", - "architecture": "RegNet", - "publication_year": 2020, - "size": (224, 224), - "min_size": (1, 1), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, -} - - -def _regnet( - block_params: BlockParams, - weights: Optional[WeightsEnum], - progress: bool, - **kwargs: Any, -) -> RegNet: - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - - norm_layer = kwargs.pop("norm_layer", partial(nn.BatchNorm2d, eps=1e-05, momentum=0.1)) - model = RegNet(block_params, norm_layer=norm_layer, **kwargs) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model - - -class RegNet_Y_400MF_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 4344144, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models", - "acc@1": 74.046, - "acc@5": 91.716, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/regnet_y_400mf-e6988f5f.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "num_params": 4344144, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", - "acc@1": 75.804, - "acc@5": 92.742, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class RegNet_Y_800MF_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth", - 
transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 6432512, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models", - "acc@1": 76.420, - "acc@5": 93.136, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/regnet_y_800mf-58fc7688.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "num_params": 6432512, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", - "acc@1": 78.828, - "acc@5": 94.502, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class RegNet_Y_1_6GF_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 11202430, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models", - "acc@1": 77.950, - "acc@5": 93.966, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/regnet_y_1_6gf-0d7bc02a.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "num_params": 11202430, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", - "acc@1": 80.876, - "acc@5": 95.444, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class RegNet_Y_3_2GF_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 19436338, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models", - "acc@1": 78.948, - "acc@5": 94.576, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/regnet_y_3_2gf-9180c971.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "num_params": 19436338, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", - "acc@1": 81.982, - "acc@5": 95.972, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class RegNet_Y_8GF_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 39381472, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models", - "acc@1": 80.032, - "acc@5": 95.048, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/regnet_y_8gf-dc2b1b54.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "num_params": 39381472, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", - "acc@1": 82.828, - "acc@5": 96.330, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class RegNet_Y_16GF_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 83590140, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#large-models", - "acc@1": 80.424, - "acc@5": 95.240, - 
}, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/regnet_y_16gf-3e4a00f9.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "num_params": 83590140, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", - "acc@1": 82.886, - "acc@5": 96.328, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class RegNet_Y_32GF_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 145046770, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#large-models", - "acc@1": 80.878, - "acc@5": 95.340, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/regnet_y_32gf-8db6d4b5.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "num_params": 145046770, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", - "acc@1": 83.368, - "acc@5": 96.498, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class RegNet_Y_128GF_Weights(WeightsEnum): - # weights are not available yet. - pass - - -class RegNet_X_400MF_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 5495976, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models", - "acc@1": 72.834, - "acc@5": 90.950, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/regnet_x_400mf-62229a5f.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "num_params": 5495976, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres", - "acc@1": 74.864, - "acc@5": 92.322, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class RegNet_X_800MF_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 7259656, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models", - "acc@1": 75.212, - "acc@5": 92.348, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/regnet_x_800mf-94a99ebd.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "num_params": 7259656, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres", - "acc@1": 77.522, - "acc@5": 93.826, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class RegNet_X_1_6GF_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 9190136, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models", - "acc@1": 77.040, - "acc@5": 93.440, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/regnet_x_1_6gf-a12f2b72.pth", - transforms=partial(ImageClassificationEval, 
crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "num_params": 9190136, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres", - "acc@1": 79.668, - "acc@5": 94.922, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class RegNet_X_3_2GF_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 15296552, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models", - "acc@1": 78.364, - "acc@5": 93.992, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/regnet_x_3_2gf-7071aa85.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "num_params": 15296552, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", - "acc@1": 81.196, - "acc@5": 95.430, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class RegNet_X_8GF_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 39572648, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models", - "acc@1": 79.344, - "acc@5": 94.686, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/regnet_x_8gf-2b70d774.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "num_params": 39572648, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", - "acc@1": 81.682, - "acc@5": 95.678, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class RegNet_X_16GF_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 54278536, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models", - "acc@1": 80.058, - "acc@5": 94.944, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/regnet_x_16gf-ba3796d7.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "num_params": 54278536, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", - "acc@1": 82.716, - "acc@5": 96.196, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class RegNet_X_32GF_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 107811560, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#large-models", - "acc@1": 80.622, - "acc@5": 95.248, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/regnet_x_32gf-6eb8fdc6.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "num_params": 107811560, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", - "acc@1": 83.014, - "acc@5": 96.288, - }, - ) - DEFAULT = IMAGENET1K_V2 - - 
-@handle_legacy_interface(weights=("pretrained", RegNet_Y_400MF_Weights.IMAGENET1K_V1)) -def regnet_y_400mf(*, weights: Optional[RegNet_Y_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - weights = RegNet_Y_400MF_Weights.verify(weights) - - params = BlockParams.from_init_params(depth=16, w_0=48, w_a=27.89, w_m=2.09, group_width=8, se_ratio=0.25, **kwargs) - return _regnet(params, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", RegNet_Y_800MF_Weights.IMAGENET1K_V1)) -def regnet_y_800mf(*, weights: Optional[RegNet_Y_800MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - weights = RegNet_Y_800MF_Weights.verify(weights) - - params = BlockParams.from_init_params(depth=14, w_0=56, w_a=38.84, w_m=2.4, group_width=16, se_ratio=0.25, **kwargs) - return _regnet(params, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", RegNet_Y_1_6GF_Weights.IMAGENET1K_V1)) -def regnet_y_1_6gf(*, weights: Optional[RegNet_Y_1_6GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - weights = RegNet_Y_1_6GF_Weights.verify(weights) - - params = BlockParams.from_init_params( - depth=27, w_0=48, w_a=20.71, w_m=2.65, group_width=24, se_ratio=0.25, **kwargs - ) - return _regnet(params, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", RegNet_Y_3_2GF_Weights.IMAGENET1K_V1)) -def regnet_y_3_2gf(*, weights: Optional[RegNet_Y_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - weights = RegNet_Y_3_2GF_Weights.verify(weights) - - params = BlockParams.from_init_params( - depth=21, w_0=80, w_a=42.63, w_m=2.66, group_width=24, se_ratio=0.25, **kwargs - ) - return _regnet(params, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", RegNet_Y_8GF_Weights.IMAGENET1K_V1)) -def regnet_y_8gf(*, weights: Optional[RegNet_Y_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - weights = RegNet_Y_8GF_Weights.verify(weights) - - params = BlockParams.from_init_params( - depth=17, w_0=192, w_a=76.82, w_m=2.19, group_width=56, se_ratio=0.25, **kwargs - ) - return _regnet(params, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", RegNet_Y_16GF_Weights.IMAGENET1K_V1)) -def regnet_y_16gf(*, weights: Optional[RegNet_Y_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - weights = RegNet_Y_16GF_Weights.verify(weights) - - params = BlockParams.from_init_params( - depth=18, w_0=200, w_a=106.23, w_m=2.48, group_width=112, se_ratio=0.25, **kwargs - ) - return _regnet(params, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", RegNet_Y_32GF_Weights.IMAGENET1K_V1)) -def regnet_y_32gf(*, weights: Optional[RegNet_Y_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - weights = RegNet_Y_32GF_Weights.verify(weights) - - params = BlockParams.from_init_params( - depth=20, w_0=232, w_a=115.89, w_m=2.53, group_width=232, se_ratio=0.25, **kwargs - ) - return _regnet(params, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", None)) -def regnet_y_128gf(*, weights: Optional[RegNet_Y_128GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - weights = RegNet_Y_128GF_Weights.verify(weights) - - params = BlockParams.from_init_params( - depth=27, w_0=456, w_a=160.83, w_m=2.52, group_width=264, se_ratio=0.25, **kwargs - ) - return _regnet(params, weights, progress, **kwargs) - - 
-@handle_legacy_interface(weights=("pretrained", RegNet_X_400MF_Weights.IMAGENET1K_V1)) -def regnet_x_400mf(*, weights: Optional[RegNet_X_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - weights = RegNet_X_400MF_Weights.verify(weights) - - params = BlockParams.from_init_params(depth=22, w_0=24, w_a=24.48, w_m=2.54, group_width=16, **kwargs) - return _regnet(params, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", RegNet_X_800MF_Weights.IMAGENET1K_V1)) -def regnet_x_800mf(*, weights: Optional[RegNet_X_800MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - weights = RegNet_X_800MF_Weights.verify(weights) - - params = BlockParams.from_init_params(depth=16, w_0=56, w_a=35.73, w_m=2.28, group_width=16, **kwargs) - return _regnet(params, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", RegNet_X_1_6GF_Weights.IMAGENET1K_V1)) -def regnet_x_1_6gf(*, weights: Optional[RegNet_X_1_6GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - weights = RegNet_X_1_6GF_Weights.verify(weights) - - params = BlockParams.from_init_params(depth=18, w_0=80, w_a=34.01, w_m=2.25, group_width=24, **kwargs) - return _regnet(params, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", RegNet_X_3_2GF_Weights.IMAGENET1K_V1)) -def regnet_x_3_2gf(*, weights: Optional[RegNet_X_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - weights = RegNet_X_3_2GF_Weights.verify(weights) - - params = BlockParams.from_init_params(depth=25, w_0=88, w_a=26.31, w_m=2.25, group_width=48, **kwargs) - return _regnet(params, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", RegNet_X_8GF_Weights.IMAGENET1K_V1)) -def regnet_x_8gf(*, weights: Optional[RegNet_X_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - weights = RegNet_X_8GF_Weights.verify(weights) - - params = BlockParams.from_init_params(depth=23, w_0=80, w_a=49.56, w_m=2.88, group_width=120, **kwargs) - return _regnet(params, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", RegNet_X_16GF_Weights.IMAGENET1K_V1)) -def regnet_x_16gf(*, weights: Optional[RegNet_X_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - weights = RegNet_X_16GF_Weights.verify(weights) - - params = BlockParams.from_init_params(depth=22, w_0=216, w_a=55.59, w_m=2.1, group_width=128, **kwargs) - return _regnet(params, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", RegNet_X_32GF_Weights.IMAGENET1K_V1)) -def regnet_x_32gf(*, weights: Optional[RegNet_X_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - weights = RegNet_X_32GF_Weights.verify(weights) - - params = BlockParams.from_init_params(depth=23, w_0=320, w_a=69.86, w_m=2.0, group_width=168, **kwargs) - return _regnet(params, weights, progress, **kwargs) diff --git a/torchvision/prototype/models/resnet.py b/torchvision/prototype/models/resnet.py deleted file mode 100644 index 35e30c0e760..00000000000 --- a/torchvision/prototype/models/resnet.py +++ /dev/null @@ -1,381 +0,0 @@ -from functools import partial -from typing import Any, List, Optional, Type, Union - -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ...models.resnet import BasicBlock, Bottleneck, ResNet -from ._api import WeightsEnum, Weights -from ._meta import 
_IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param - - -__all__ = [ - "ResNet", - "ResNet18_Weights", - "ResNet34_Weights", - "ResNet50_Weights", - "ResNet101_Weights", - "ResNet152_Weights", - "ResNeXt50_32X4D_Weights", - "ResNeXt101_32X8D_Weights", - "Wide_ResNet50_2_Weights", - "Wide_ResNet101_2_Weights", - "resnet18", - "resnet34", - "resnet50", - "resnet101", - "resnet152", - "resnext50_32x4d", - "resnext101_32x8d", - "wide_resnet50_2", - "wide_resnet101_2", -] - - -def _resnet( - block: Type[Union[BasicBlock, Bottleneck]], - layers: List[int], - weights: Optional[WeightsEnum], - progress: bool, - **kwargs: Any, -) -> ResNet: - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - - model = ResNet(block, layers, **kwargs) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model - - -_COMMON_META = { - "task": "image_classification", - "size": (224, 224), - "min_size": (1, 1), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, -} - - -class ResNet18_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/resnet18-f37072fd.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "architecture": "ResNet", - "publication_year": 2015, - "num_params": 11689512, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet", - "acc@1": 69.758, - "acc@5": 89.078, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class ResNet34_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/resnet34-b627a593.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "architecture": "ResNet", - "publication_year": 2015, - "num_params": 21797672, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet", - "acc@1": 73.314, - "acc@5": 91.420, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class ResNet50_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/resnet50-0676ba61.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "architecture": "ResNet", - "publication_year": 2015, - "num_params": 25557032, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet", - "acc@1": 76.130, - "acc@5": 92.862, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/resnet50-11ad3fa6.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "architecture": "ResNet", - "publication_year": 2015, - "num_params": 25557032, - "recipe": "https://github.com/pytorch/vision/issues/3995#issuecomment-1013906621", - "acc@1": 80.858, - "acc@5": 95.434, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class ResNet101_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/resnet101-63fe2227.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "architecture": "ResNet", - "publication_year": 2015, - "num_params": 44549160, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet", - "acc@1": 77.374, - "acc@5": 93.546, - }, - ) - IMAGENET1K_V2 = Weights( - 
url="https://download.pytorch.org/models/resnet101-cd907fc2.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "architecture": "ResNet", - "publication_year": 2015, - "num_params": 44549160, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", - "acc@1": 81.886, - "acc@5": 95.780, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class ResNet152_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/resnet152-394f9c45.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "architecture": "ResNet", - "publication_year": 2015, - "num_params": 60192808, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet", - "acc@1": 78.312, - "acc@5": 94.046, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/resnet152-f82ba261.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "architecture": "ResNet", - "publication_year": 2015, - "num_params": 60192808, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", - "acc@1": 82.284, - "acc@5": 96.002, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class ResNeXt50_32X4D_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "architecture": "ResNeXt", - "publication_year": 2016, - "num_params": 25028904, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnext", - "acc@1": 77.618, - "acc@5": 93.698, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/resnext50_32x4d-1a0047aa.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "architecture": "ResNeXt", - "publication_year": 2016, - "num_params": 25028904, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", - "acc@1": 81.198, - "acc@5": 95.340, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class ResNeXt101_32X8D_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "architecture": "ResNeXt", - "publication_year": 2016, - "num_params": 88791336, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnext", - "acc@1": 79.312, - "acc@5": 94.526, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/resnext101_32x8d-110c445d.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "architecture": "ResNeXt", - "publication_year": 2016, - "num_params": 88791336, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres", - "acc@1": 82.834, - "acc@5": 96.228, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class Wide_ResNet50_2_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "architecture": "WideResNet", - "publication_year": 2016, - "num_params": 68883240, - "recipe": 
"https://github.com/pytorch/vision/pull/912#issue-445437439", - "acc@1": 78.468, - "acc@5": 94.086, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/wide_resnet50_2-9ba9bcbe.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "architecture": "WideResNet", - "publication_year": 2016, - "num_params": 68883240, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres", - "acc@1": 81.602, - "acc@5": 95.758, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class Wide_ResNet101_2_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "architecture": "WideResNet", - "publication_year": 2016, - "num_params": 126886696, - "recipe": "https://github.com/pytorch/vision/pull/912#issue-445437439", - "acc@1": 78.848, - "acc@5": 94.284, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/wide_resnet101_2-d733dc28.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "architecture": "WideResNet", - "publication_year": 2016, - "num_params": 126886696, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", - "acc@1": 82.510, - "acc@5": 96.020, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -@handle_legacy_interface(weights=("pretrained", ResNet18_Weights.IMAGENET1K_V1)) -def resnet18(*, weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: - weights = ResNet18_Weights.verify(weights) - - return _resnet(BasicBlock, [2, 2, 2, 2], weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", ResNet34_Weights.IMAGENET1K_V1)) -def resnet34(*, weights: Optional[ResNet34_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: - weights = ResNet34_Weights.verify(weights) - - return _resnet(BasicBlock, [3, 4, 6, 3], weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", ResNet50_Weights.IMAGENET1K_V1)) -def resnet50(*, weights: Optional[ResNet50_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: - weights = ResNet50_Weights.verify(weights) - - return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", ResNet101_Weights.IMAGENET1K_V1)) -def resnet101(*, weights: Optional[ResNet101_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: - weights = ResNet101_Weights.verify(weights) - - return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", ResNet152_Weights.IMAGENET1K_V1)) -def resnet152(*, weights: Optional[ResNet152_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: - weights = ResNet152_Weights.verify(weights) - - return _resnet(Bottleneck, [3, 8, 36, 3], weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", ResNeXt50_32X4D_Weights.IMAGENET1K_V1)) -def resnext50_32x4d( - *, weights: Optional[ResNeXt50_32X4D_Weights] = None, progress: bool = True, **kwargs: Any -) -> ResNet: - weights = ResNeXt50_32X4D_Weights.verify(weights) - - _ovewrite_named_param(kwargs, "groups", 32) - _ovewrite_named_param(kwargs, "width_per_group", 4) - return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, 
**kwargs) - - -@handle_legacy_interface(weights=("pretrained", ResNeXt101_32X8D_Weights.IMAGENET1K_V1)) -def resnext101_32x8d( - *, weights: Optional[ResNeXt101_32X8D_Weights] = None, progress: bool = True, **kwargs: Any -) -> ResNet: - weights = ResNeXt101_32X8D_Weights.verify(weights) - - _ovewrite_named_param(kwargs, "groups", 32) - _ovewrite_named_param(kwargs, "width_per_group", 8) - return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", Wide_ResNet50_2_Weights.IMAGENET1K_V1)) -def wide_resnet50_2( - *, weights: Optional[Wide_ResNet50_2_Weights] = None, progress: bool = True, **kwargs: Any -) -> ResNet: - weights = Wide_ResNet50_2_Weights.verify(weights) - - _ovewrite_named_param(kwargs, "width_per_group", 64 * 2) - return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", Wide_ResNet101_2_Weights.IMAGENET1K_V1)) -def wide_resnet101_2( - *, weights: Optional[Wide_ResNet101_2_Weights] = None, progress: bool = True, **kwargs: Any -) -> ResNet: - weights = Wide_ResNet101_2_Weights.verify(weights) - - _ovewrite_named_param(kwargs, "width_per_group", 64 * 2) - return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs) diff --git a/torchvision/prototype/models/segmentation/__init__.py b/torchvision/prototype/models/segmentation/__init__.py deleted file mode 100644 index 20273be2170..00000000000 --- a/torchvision/prototype/models/segmentation/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .fcn import * -from .lraspp import * -from .deeplabv3 import * diff --git a/torchvision/prototype/models/segmentation/deeplabv3.py b/torchvision/prototype/models/segmentation/deeplabv3.py deleted file mode 100644 index 7165078161f..00000000000 --- a/torchvision/prototype/models/segmentation/deeplabv3.py +++ /dev/null @@ -1,174 +0,0 @@ -from functools import partial -from typing import Any, Optional - -from torchvision.prototype.transforms import SemanticSegmentationEval -from torchvision.transforms.functional import InterpolationMode - -from ....models.segmentation.deeplabv3 import DeepLabV3, _deeplabv3_mobilenetv3, _deeplabv3_resnet -from .._api import WeightsEnum, Weights -from .._meta import _VOC_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_value_param -from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large -from ..resnet import resnet50, resnet101 -from ..resnet import ResNet50_Weights, ResNet101_Weights - - -__all__ = [ - "DeepLabV3", - "DeepLabV3_ResNet50_Weights", - "DeepLabV3_ResNet101_Weights", - "DeepLabV3_MobileNet_V3_Large_Weights", - "deeplabv3_mobilenet_v3_large", - "deeplabv3_resnet50", - "deeplabv3_resnet101", -] - - -_COMMON_META = { - "task": "image_semantic_segmentation", - "architecture": "DeepLabV3", - "publication_year": 2017, - "categories": _VOC_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, -} - - -class DeepLabV3_ResNet50_Weights(WeightsEnum): - COCO_WITH_VOC_LABELS_V1 = Weights( - url="https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth", - transforms=partial(SemanticSegmentationEval, resize_size=520), - meta={ - **_COMMON_META, - "num_params": 42004074, - "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_resnet50", - "mIoU": 66.4, - "acc": 92.4, - }, - ) - DEFAULT = COCO_WITH_VOC_LABELS_V1 - - -class DeepLabV3_ResNet101_Weights(WeightsEnum): - COCO_WITH_VOC_LABELS_V1 = Weights( - 
url="https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth", - transforms=partial(SemanticSegmentationEval, resize_size=520), - meta={ - **_COMMON_META, - "num_params": 60996202, - "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#fcn_resnet101", - "mIoU": 67.4, - "acc": 92.4, - }, - ) - DEFAULT = COCO_WITH_VOC_LABELS_V1 - - -class DeepLabV3_MobileNet_V3_Large_Weights(WeightsEnum): - COCO_WITH_VOC_LABELS_V1 = Weights( - url="https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth", - transforms=partial(SemanticSegmentationEval, resize_size=520), - meta={ - **_COMMON_META, - "num_params": 11029328, - "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_mobilenet_v3_large", - "mIoU": 60.3, - "acc": 91.2, - }, - ) - DEFAULT = COCO_WITH_VOC_LABELS_V1 - - -@handle_legacy_interface( - weights=("pretrained", DeepLabV3_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1), - weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), -) -def deeplabv3_resnet50( - *, - weights: Optional[DeepLabV3_ResNet50_Weights] = None, - progress: bool = True, - num_classes: Optional[int] = None, - aux_loss: Optional[bool] = None, - weights_backbone: Optional[ResNet50_Weights] = None, - **kwargs: Any, -) -> DeepLabV3: - weights = DeepLabV3_ResNet50_Weights.verify(weights) - weights_backbone = ResNet50_Weights.verify(weights_backbone) - - if weights is not None: - weights_backbone = None - num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) - aux_loss = _ovewrite_value_param(aux_loss, True) - elif num_classes is None: - num_classes = 21 - - backbone = resnet50(weights=weights_backbone, replace_stride_with_dilation=[False, True, True]) - model = _deeplabv3_resnet(backbone, num_classes, aux_loss) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model - - -@handle_legacy_interface( - weights=("pretrained", DeepLabV3_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1), - weights_backbone=("pretrained_backbone", ResNet101_Weights.IMAGENET1K_V1), -) -def deeplabv3_resnet101( - *, - weights: Optional[DeepLabV3_ResNet101_Weights] = None, - progress: bool = True, - num_classes: Optional[int] = None, - aux_loss: Optional[bool] = None, - weights_backbone: Optional[ResNet101_Weights] = None, - **kwargs: Any, -) -> DeepLabV3: - weights = DeepLabV3_ResNet101_Weights.verify(weights) - weights_backbone = ResNet101_Weights.verify(weights_backbone) - - if weights is not None: - weights_backbone = None - num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) - aux_loss = _ovewrite_value_param(aux_loss, True) - elif num_classes is None: - num_classes = 21 - - backbone = resnet101(weights=weights_backbone, replace_stride_with_dilation=[False, True, True]) - model = _deeplabv3_resnet(backbone, num_classes, aux_loss) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model - - -@handle_legacy_interface( - weights=("pretrained", DeepLabV3_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1), - weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1), -) -def deeplabv3_mobilenet_v3_large( - *, - weights: Optional[DeepLabV3_MobileNet_V3_Large_Weights] = None, - progress: bool = True, - num_classes: Optional[int] = None, - aux_loss: Optional[bool] = None, - weights_backbone: Optional[MobileNet_V3_Large_Weights] = 
None, - **kwargs: Any, -) -> DeepLabV3: - weights = DeepLabV3_MobileNet_V3_Large_Weights.verify(weights) - weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone) - - if weights is not None: - weights_backbone = None - num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) - aux_loss = _ovewrite_value_param(aux_loss, True) - elif num_classes is None: - num_classes = 21 - - backbone = mobilenet_v3_large(weights=weights_backbone, dilated=True) - model = _deeplabv3_mobilenetv3(backbone, num_classes, aux_loss) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model diff --git a/torchvision/prototype/models/segmentation/fcn.py b/torchvision/prototype/models/segmentation/fcn.py deleted file mode 100644 index 1dfc251844f..00000000000 --- a/torchvision/prototype/models/segmentation/fcn.py +++ /dev/null @@ -1,117 +0,0 @@ -from functools import partial -from typing import Any, Optional - -from torchvision.prototype.transforms import SemanticSegmentationEval -from torchvision.transforms.functional import InterpolationMode - -from ....models.segmentation.fcn import FCN, _fcn_resnet -from .._api import WeightsEnum, Weights -from .._meta import _VOC_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_value_param -from ..resnet import ResNet50_Weights, ResNet101_Weights, resnet50, resnet101 - - -__all__ = ["FCN", "FCN_ResNet50_Weights", "FCN_ResNet101_Weights", "fcn_resnet50", "fcn_resnet101"] - - -_COMMON_META = { - "task": "image_semantic_segmentation", - "architecture": "FCN", - "publication_year": 2014, - "categories": _VOC_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, -} - - -class FCN_ResNet50_Weights(WeightsEnum): - COCO_WITH_VOC_LABELS_V1 = Weights( - url="https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth", - transforms=partial(SemanticSegmentationEval, resize_size=520), - meta={ - **_COMMON_META, - "num_params": 35322218, - "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#fcn_resnet50", - "mIoU": 60.5, - "acc": 91.4, - }, - ) - DEFAULT = COCO_WITH_VOC_LABELS_V1 - - -class FCN_ResNet101_Weights(WeightsEnum): - COCO_WITH_VOC_LABELS_V1 = Weights( - url="https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth", - transforms=partial(SemanticSegmentationEval, resize_size=520), - meta={ - **_COMMON_META, - "num_params": 54314346, - "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_resnet101", - "mIoU": 63.7, - "acc": 91.9, - }, - ) - DEFAULT = COCO_WITH_VOC_LABELS_V1 - - -@handle_legacy_interface( - weights=("pretrained", FCN_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1), - weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), -) -def fcn_resnet50( - *, - weights: Optional[FCN_ResNet50_Weights] = None, - progress: bool = True, - num_classes: Optional[int] = None, - aux_loss: Optional[bool] = None, - weights_backbone: Optional[ResNet50_Weights] = None, - **kwargs: Any, -) -> FCN: - weights = FCN_ResNet50_Weights.verify(weights) - weights_backbone = ResNet50_Weights.verify(weights_backbone) - - if weights is not None: - weights_backbone = None - num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) - aux_loss = _ovewrite_value_param(aux_loss, True) - elif num_classes is None: - num_classes = 21 - - backbone = resnet50(weights=weights_backbone, replace_stride_with_dilation=[False, True, True]) - model = 
_fcn_resnet(backbone, num_classes, aux_loss) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model - - -@handle_legacy_interface( - weights=("pretrained", FCN_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1), - weights_backbone=("pretrained_backbone", ResNet101_Weights.IMAGENET1K_V1), -) -def fcn_resnet101( - *, - weights: Optional[FCN_ResNet101_Weights] = None, - progress: bool = True, - num_classes: Optional[int] = None, - aux_loss: Optional[bool] = None, - weights_backbone: Optional[ResNet101_Weights] = None, - **kwargs: Any, -) -> FCN: - weights = FCN_ResNet101_Weights.verify(weights) - weights_backbone = ResNet101_Weights.verify(weights_backbone) - - if weights is not None: - weights_backbone = None - num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) - aux_loss = _ovewrite_value_param(aux_loss, True) - elif num_classes is None: - num_classes = 21 - - backbone = resnet101(weights=weights_backbone, replace_stride_with_dilation=[False, True, True]) - model = _fcn_resnet(backbone, num_classes, aux_loss) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model diff --git a/torchvision/prototype/models/segmentation/lraspp.py b/torchvision/prototype/models/segmentation/lraspp.py deleted file mode 100644 index 2c0fa6f0aff..00000000000 --- a/torchvision/prototype/models/segmentation/lraspp.py +++ /dev/null @@ -1,66 +0,0 @@ -from functools import partial -from typing import Any, Optional - -from torchvision.prototype.transforms import SemanticSegmentationEval -from torchvision.transforms.functional import InterpolationMode - -from ....models.segmentation.lraspp import LRASPP, _lraspp_mobilenetv3 -from .._api import WeightsEnum, Weights -from .._meta import _VOC_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_value_param -from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large - - -__all__ = ["LRASPP", "LRASPP_MobileNet_V3_Large_Weights", "lraspp_mobilenet_v3_large"] - - -class LRASPP_MobileNet_V3_Large_Weights(WeightsEnum): - COCO_WITH_VOC_LABELS_V1 = Weights( - url="https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth", - transforms=partial(SemanticSegmentationEval, resize_size=520), - meta={ - "task": "image_semantic_segmentation", - "architecture": "LRASPP", - "publication_year": 2019, - "num_params": 3221538, - "categories": _VOC_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#lraspp_mobilenet_v3_large", - "mIoU": 57.9, - "acc": 91.2, - }, - ) - DEFAULT = COCO_WITH_VOC_LABELS_V1 - - -@handle_legacy_interface( - weights=("pretrained", LRASPP_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1), - weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1), -) -def lraspp_mobilenet_v3_large( - *, - weights: Optional[LRASPP_MobileNet_V3_Large_Weights] = None, - progress: bool = True, - num_classes: Optional[int] = None, - weights_backbone: Optional[MobileNet_V3_Large_Weights] = None, - **kwargs: Any, -) -> LRASPP: - if kwargs.pop("aux_loss", False): - raise NotImplementedError("This model does not use auxiliary loss") - - weights = LRASPP_MobileNet_V3_Large_Weights.verify(weights) - weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone) - - if weights is not None: - weights_backbone = None - num_classes = _ovewrite_value_param(num_classes, 
len(weights.meta["categories"])) - elif num_classes is None: - num_classes = 21 - - backbone = mobilenet_v3_large(weights=weights_backbone, dilated=True) - model = _lraspp_mobilenetv3(backbone, num_classes) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model diff --git a/torchvision/prototype/models/shufflenetv2.py b/torchvision/prototype/models/shufflenetv2.py deleted file mode 100644 index 48047a70c60..00000000000 --- a/torchvision/prototype/models/shufflenetv2.py +++ /dev/null @@ -1,124 +0,0 @@ -from functools import partial -from typing import Any, Optional - -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ...models.shufflenetv2 import ShuffleNetV2 -from ._api import WeightsEnum, Weights -from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param - - -__all__ = [ - "ShuffleNetV2", - "ShuffleNet_V2_X0_5_Weights", - "ShuffleNet_V2_X1_0_Weights", - "ShuffleNet_V2_X1_5_Weights", - "ShuffleNet_V2_X2_0_Weights", - "shufflenet_v2_x0_5", - "shufflenet_v2_x1_0", - "shufflenet_v2_x1_5", - "shufflenet_v2_x2_0", -] - - -def _shufflenetv2( - weights: Optional[WeightsEnum], - progress: bool, - *args: Any, - **kwargs: Any, -) -> ShuffleNetV2: - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - - model = ShuffleNetV2(*args, **kwargs) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model - - -_COMMON_META = { - "task": "image_classification", - "architecture": "ShuffleNetV2", - "publication_year": 2018, - "size": (224, 224), - "min_size": (1, 1), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "recipe": "https://github.com/barrh/Shufflenet-v2-Pytorch/tree/v0.1.0", -} - - -class ShuffleNet_V2_X0_5_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 1366792, - "acc@1": 69.362, - "acc@5": 88.316, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class ShuffleNet_V2_X1_0_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 2278604, - "acc@1": 60.552, - "acc@5": 81.746, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class ShuffleNet_V2_X1_5_Weights(WeightsEnum): - pass - - -class ShuffleNet_V2_X2_0_Weights(WeightsEnum): - pass - - -@handle_legacy_interface(weights=("pretrained", ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1)) -def shufflenet_v2_x0_5( - *, weights: Optional[ShuffleNet_V2_X0_5_Weights] = None, progress: bool = True, **kwargs: Any -) -> ShuffleNetV2: - weights = ShuffleNet_V2_X0_5_Weights.verify(weights) - - return _shufflenetv2(weights, progress, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs) - - -@handle_legacy_interface(weights=("pretrained", ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1)) -def shufflenet_v2_x1_0( - *, weights: Optional[ShuffleNet_V2_X1_0_Weights] = None, progress: bool = True, **kwargs: Any -) -> ShuffleNetV2: - weights = ShuffleNet_V2_X1_0_Weights.verify(weights) - - return _shufflenetv2(weights, progress, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs) - - 
-@handle_legacy_interface(weights=("pretrained", None)) -def shufflenet_v2_x1_5( - *, weights: Optional[ShuffleNet_V2_X1_5_Weights] = None, progress: bool = True, **kwargs: Any -) -> ShuffleNetV2: - weights = ShuffleNet_V2_X1_5_Weights.verify(weights) - - return _shufflenetv2(weights, progress, [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs) - - -@handle_legacy_interface(weights=("pretrained", None)) -def shufflenet_v2_x2_0( - *, weights: Optional[ShuffleNet_V2_X2_0_Weights] = None, progress: bool = True, **kwargs: Any -) -> ShuffleNetV2: - weights = ShuffleNet_V2_X2_0_Weights.verify(weights) - - return _shufflenetv2(weights, progress, [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs) diff --git a/torchvision/prototype/models/squeezenet.py b/torchvision/prototype/models/squeezenet.py deleted file mode 100644 index 7f6a034ed6c..00000000000 --- a/torchvision/prototype/models/squeezenet.py +++ /dev/null @@ -1,88 +0,0 @@ -from functools import partial -from typing import Any, Optional - -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ...models.squeezenet import SqueezeNet -from ._api import WeightsEnum, Weights -from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param - - -__all__ = ["SqueezeNet", "SqueezeNet1_0_Weights", "SqueezeNet1_1_Weights", "squeezenet1_0", "squeezenet1_1"] - - -_COMMON_META = { - "task": "image_classification", - "architecture": "SqueezeNet", - "publication_year": 2016, - "size": (224, 224), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "recipe": "https://github.com/pytorch/vision/pull/49#issuecomment-277560717", -} - - -class SqueezeNet1_0_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "min_size": (21, 21), - "num_params": 1248424, - "acc@1": 58.092, - "acc@5": 80.420, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class SqueezeNet1_1_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "min_size": (17, 17), - "num_params": 1235496, - "acc@1": 58.178, - "acc@5": 80.624, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -@handle_legacy_interface(weights=("pretrained", SqueezeNet1_0_Weights.IMAGENET1K_V1)) -def squeezenet1_0( - *, weights: Optional[SqueezeNet1_0_Weights] = None, progress: bool = True, **kwargs: Any -) -> SqueezeNet: - weights = SqueezeNet1_0_Weights.verify(weights) - - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - - model = SqueezeNet("1_0", **kwargs) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model - - -@handle_legacy_interface(weights=("pretrained", SqueezeNet1_1_Weights.IMAGENET1K_V1)) -def squeezenet1_1( - *, weights: Optional[SqueezeNet1_1_Weights] = None, progress: bool = True, **kwargs: Any -) -> SqueezeNet: - weights = SqueezeNet1_1_Weights.verify(weights) - - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - - model = SqueezeNet("1_1", **kwargs) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model diff --git 
a/torchvision/prototype/models/vgg.py b/torchvision/prototype/models/vgg.py deleted file mode 100644 index 233c35418ed..00000000000 --- a/torchvision/prototype/models/vgg.py +++ /dev/null @@ -1,240 +0,0 @@ -from functools import partial -from typing import Any, Optional - -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ...models.vgg import VGG, make_layers, cfgs -from ._api import WeightsEnum, Weights -from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param - - -__all__ = [ - "VGG", - "VGG11_Weights", - "VGG11_BN_Weights", - "VGG13_Weights", - "VGG13_BN_Weights", - "VGG16_Weights", - "VGG16_BN_Weights", - "VGG19_Weights", - "VGG19_BN_Weights", - "vgg11", - "vgg11_bn", - "vgg13", - "vgg13_bn", - "vgg16", - "vgg16_bn", - "vgg19", - "vgg19_bn", -] - - -def _vgg(cfg: str, batch_norm: bool, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any) -> VGG: - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - return model - - -_COMMON_META = { - "task": "image_classification", - "architecture": "VGG", - "publication_year": 2014, - "size": (224, 224), - "min_size": (32, 32), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg", -} - - -class VGG11_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/vgg11-8a719046.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 132863336, - "acc@1": 69.020, - "acc@5": 88.628, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class VGG11_BN_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/vgg11_bn-6002323d.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 132868840, - "acc@1": 70.370, - "acc@5": 89.810, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class VGG13_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/vgg13-19584684.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 133047848, - "acc@1": 69.928, - "acc@5": 89.246, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class VGG13_BN_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/vgg13_bn-abd245e5.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 133053736, - "acc@1": 71.586, - "acc@5": 90.374, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class VGG16_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/vgg16-397923af.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 138357544, - "acc@1": 71.592, - "acc@5": 90.382, - }, - ) - # We port the features of a VGG16 backbone trained by amdegroot because unlike the one on TorchVision, it uses the - # same input standardization method as the paper. Only the `features` weights have proper values, those on the - # `classifier` module are filled with nans. 
- IMAGENET1K_FEATURES = Weights( - url="https://download.pytorch.org/models/vgg16_features-amdegroot-88682ab5.pth", - transforms=partial( - ImageClassificationEval, - crop_size=224, - mean=(0.48235, 0.45882, 0.40784), - std=(1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0), - ), - meta={ - **_COMMON_META, - "num_params": 138357544, - "categories": None, - "recipe": "https://github.com/amdegroot/ssd.pytorch#training-ssd", - "acc@1": float("nan"), - "acc@5": float("nan"), - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class VGG16_BN_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/vgg16_bn-6c64b313.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 138365992, - "acc@1": 73.360, - "acc@5": 91.516, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class VGG19_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/vgg19-dcbb9e9d.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 143667240, - "acc@1": 72.376, - "acc@5": 90.876, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class VGG19_BN_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/vgg19_bn-c79401a0.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 143678248, - "acc@1": 74.218, - "acc@5": 91.842, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -@handle_legacy_interface(weights=("pretrained", VGG11_Weights.IMAGENET1K_V1)) -def vgg11(*, weights: Optional[VGG11_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: - weights = VGG11_Weights.verify(weights) - - return _vgg("A", False, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", VGG11_BN_Weights.IMAGENET1K_V1)) -def vgg11_bn(*, weights: Optional[VGG11_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: - weights = VGG11_BN_Weights.verify(weights) - - return _vgg("A", True, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", VGG13_Weights.IMAGENET1K_V1)) -def vgg13(*, weights: Optional[VGG13_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: - weights = VGG13_Weights.verify(weights) - - return _vgg("B", False, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", VGG13_BN_Weights.IMAGENET1K_V1)) -def vgg13_bn(*, weights: Optional[VGG13_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: - weights = VGG13_BN_Weights.verify(weights) - - return _vgg("B", True, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", VGG16_Weights.IMAGENET1K_V1)) -def vgg16(*, weights: Optional[VGG16_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: - weights = VGG16_Weights.verify(weights) - - return _vgg("D", False, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", VGG16_BN_Weights.IMAGENET1K_V1)) -def vgg16_bn(*, weights: Optional[VGG16_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: - weights = VGG16_BN_Weights.verify(weights) - - return _vgg("D", True, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", VGG19_Weights.IMAGENET1K_V1)) -def vgg19(*, weights: Optional[VGG19_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: - weights = VGG19_Weights.verify(weights) - - return _vgg("E", False, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", 
VGG19_BN_Weights.IMAGENET1K_V1)) -def vgg19_bn(*, weights: Optional[VGG19_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: - weights = VGG19_BN_Weights.verify(weights) - - return _vgg("E", True, weights, progress, **kwargs) diff --git a/torchvision/prototype/models/video/__init__.py b/torchvision/prototype/models/video/__init__.py deleted file mode 100644 index b792ca6ecf7..00000000000 --- a/torchvision/prototype/models/video/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .resnet import * diff --git a/torchvision/prototype/models/video/resnet.py b/torchvision/prototype/models/video/resnet.py deleted file mode 100644 index 790d254d266..00000000000 --- a/torchvision/prototype/models/video/resnet.py +++ /dev/null @@ -1,152 +0,0 @@ -from functools import partial -from typing import Any, Callable, List, Optional, Sequence, Type, Union - -from torch import nn -from torchvision.prototype.transforms import VideoClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ....models.video.resnet import ( - BasicBlock, - BasicStem, - Bottleneck, - Conv2Plus1D, - Conv3DSimple, - Conv3DNoTemporal, - R2Plus1dStem, - VideoResNet, -) -from .._api import WeightsEnum, Weights -from .._meta import _KINETICS400_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_named_param - - -__all__ = [ - "VideoResNet", - "R3D_18_Weights", - "MC3_18_Weights", - "R2Plus1D_18_Weights", - "r3d_18", - "mc3_18", - "r2plus1d_18", -] - - -def _video_resnet( - block: Type[Union[BasicBlock, Bottleneck]], - conv_makers: Sequence[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]], - layers: List[int], - stem: Callable[..., nn.Module], - weights: Optional[WeightsEnum], - progress: bool, - **kwargs: Any, -) -> VideoResNet: - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - - model = VideoResNet(block, conv_makers, layers, stem, **kwargs) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model - - -_COMMON_META = { - "task": "video_classification", - "publication_year": 2017, - "size": (112, 112), - "min_size": (1, 1), - "categories": _KINETICS400_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "recipe": "https://github.com/pytorch/vision/tree/main/references/video_classification", -} - - -class R3D_18_Weights(WeightsEnum): - KINETICS400_V1 = Weights( - url="https://download.pytorch.org/models/r3d_18-b3b3357e.pth", - transforms=partial(VideoClassificationEval, crop_size=(112, 112), resize_size=(128, 171)), - meta={ - **_COMMON_META, - "architecture": "R3D", - "num_params": 33371472, - "acc@1": 52.75, - "acc@5": 75.45, - }, - ) - DEFAULT = KINETICS400_V1 - - -class MC3_18_Weights(WeightsEnum): - KINETICS400_V1 = Weights( - url="https://download.pytorch.org/models/mc3_18-a90a0ba3.pth", - transforms=partial(VideoClassificationEval, crop_size=(112, 112), resize_size=(128, 171)), - meta={ - **_COMMON_META, - "architecture": "MC3", - "num_params": 11695440, - "acc@1": 53.90, - "acc@5": 76.29, - }, - ) - DEFAULT = KINETICS400_V1 - - -class R2Plus1D_18_Weights(WeightsEnum): - KINETICS400_V1 = Weights( - url="https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth", - transforms=partial(VideoClassificationEval, crop_size=(112, 112), resize_size=(128, 171)), - meta={ - **_COMMON_META, - "architecture": "R(2+1)D", - "num_params": 31505325, - "acc@1": 57.50, - "acc@5": 78.81, - }, - ) - DEFAULT = KINETICS400_V1 - - 
-@handle_legacy_interface(weights=("pretrained", R3D_18_Weights.KINETICS400_V1)) -def r3d_18(*, weights: Optional[R3D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: - weights = R3D_18_Weights.verify(weights) - - return _video_resnet( - BasicBlock, - [Conv3DSimple] * 4, - [2, 2, 2, 2], - BasicStem, - weights, - progress, - **kwargs, - ) - - -@handle_legacy_interface(weights=("pretrained", MC3_18_Weights.KINETICS400_V1)) -def mc3_18(*, weights: Optional[MC3_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: - weights = MC3_18_Weights.verify(weights) - - return _video_resnet( - BasicBlock, - [Conv3DSimple] + [Conv3DNoTemporal] * 3, # type: ignore[list-item] - [2, 2, 2, 2], - BasicStem, - weights, - progress, - **kwargs, - ) - - -@handle_legacy_interface(weights=("pretrained", R2Plus1D_18_Weights.KINETICS400_V1)) -def r2plus1d_18(*, weights: Optional[R2Plus1D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: - weights = R2Plus1D_18_Weights.verify(weights) - - return _video_resnet( - BasicBlock, - [Conv2Plus1D] * 4, - [2, 2, 2, 2], - R2Plus1dStem, - weights, - progress, - **kwargs, - ) diff --git a/torchvision/prototype/models/vision_transformer.py b/torchvision/prototype/models/vision_transformer.py deleted file mode 100644 index 468903b6b94..00000000000 --- a/torchvision/prototype/models/vision_transformer.py +++ /dev/null @@ -1,198 +0,0 @@ -# References: -# https://github.com/google-research/vision_transformer -# https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/models/vision_transformer.py - -from functools import partial -from typing import Any, Optional - -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ...models.vision_transformer import VisionTransformer, interpolate_embeddings # noqa: F401 -from ._api import WeightsEnum, Weights -from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param - -__all__ = [ - "VisionTransformer", - "ViT_B_16_Weights", - "ViT_B_32_Weights", - "ViT_L_16_Weights", - "ViT_L_32_Weights", - "vit_b_16", - "vit_b_32", - "vit_l_16", - "vit_l_32", -] - - -_COMMON_META = { - "task": "image_classification", - "architecture": "ViT", - "publication_year": 2020, - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, -} - - -class ViT_B_16_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/vit_b_16-c867db91.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 86567656, - "size": (224, 224), - "min_size": (224, 224), - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_b_16", - "acc@1": 81.072, - "acc@5": 95.318, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class ViT_B_32_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/vit_b_32-d86f8d99.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 88224232, - "size": (224, 224), - "min_size": (224, 224), - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_b_32", - "acc@1": 75.912, - "acc@5": 92.466, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class ViT_L_16_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - 
url="https://download.pytorch.org/models/vit_l_16-852ce7e3.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=242), - meta={ - **_COMMON_META, - "num_params": 304326632, - "size": (224, 224), - "min_size": (224, 224), - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_l_16", - "acc@1": 79.662, - "acc@5": 94.638, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class ViT_L_32_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/vit_l_32-c7638314.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 306535400, - "size": (224, 224), - "min_size": (224, 224), - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_l_32", - "acc@1": 76.972, - "acc@5": 93.07, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -def _vision_transformer( - patch_size: int, - num_layers: int, - num_heads: int, - hidden_dim: int, - mlp_dim: int, - weights: Optional[WeightsEnum], - progress: bool, - **kwargs: Any, -) -> VisionTransformer: - image_size = kwargs.pop("image_size", 224) - - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - - model = VisionTransformer( - image_size=image_size, - patch_size=patch_size, - num_layers=num_layers, - num_heads=num_heads, - hidden_dim=hidden_dim, - mlp_dim=mlp_dim, - **kwargs, - ) - - if weights: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model - - -@handle_legacy_interface(weights=("pretrained", ViT_B_16_Weights.IMAGENET1K_V1)) -def vit_b_16(*, weights: Optional[ViT_B_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: - weights = ViT_B_16_Weights.verify(weights) - - return _vision_transformer( - patch_size=16, - num_layers=12, - num_heads=12, - hidden_dim=768, - mlp_dim=3072, - weights=weights, - progress=progress, - **kwargs, - ) - - -@handle_legacy_interface(weights=("pretrained", ViT_B_32_Weights.IMAGENET1K_V1)) -def vit_b_32(*, weights: Optional[ViT_B_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: - weights = ViT_B_32_Weights.verify(weights) - - return _vision_transformer( - patch_size=32, - num_layers=12, - num_heads=12, - hidden_dim=768, - mlp_dim=3072, - weights=weights, - progress=progress, - **kwargs, - ) - - -@handle_legacy_interface(weights=("pretrained", ViT_L_16_Weights.IMAGENET1K_V1)) -def vit_l_16(*, weights: Optional[ViT_L_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: - weights = ViT_L_16_Weights.verify(weights) - - return _vision_transformer( - patch_size=16, - num_layers=24, - num_heads=16, - hidden_dim=1024, - mlp_dim=4096, - weights=weights, - progress=progress, - **kwargs, - ) - - -@handle_legacy_interface(weights=("pretrained", ViT_L_32_Weights.IMAGENET1K_V1)) -def vit_l_32(*, weights: Optional[ViT_L_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: - weights = ViT_L_32_Weights.verify(weights) - - return _vision_transformer( - patch_size=32, - num_layers=24, - num_heads=16, - hidden_dim=1024, - mlp_dim=4096, - weights=weights, - progress=progress, - **kwargs, - ) diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py index 81e914e8383..b0860cbc787 100644 --- a/torchvision/prototype/transforms/__init__.py +++ b/torchvision/prototype/transforms/__init__.py @@ -1,5 +1,3 @@ -from 
torchvision.transforms import InterpolationMode, AutoAugmentPolicy # usort: skip - from . import functional # usort: skip from ._transform import Transform # usort: skip @@ -21,11 +19,4 @@ ) from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace from ._misc import Identity, Normalize, ToDtype, Lambda -from ._presets import ( - ObjectDetectionEval, - ImageClassificationEval, - SemanticSegmentationEval, - VideoClassificationEval, - OpticalFlowEval, -) from ._type_conversion import DecodeImage, LabelToOneHot diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py index c451feb9a32..7fc62423ab8 100644 --- a/torchvision/prototype/transforms/_auto_augment.py +++ b/torchvision/prototype/transforms/_auto_augment.py @@ -4,9 +4,10 @@ import PIL.Image import torch from torchvision.prototype import features -from torchvision.prototype.transforms import Transform, InterpolationMode, AutoAugmentPolicy, functional as F +from torchvision.prototype.transforms import Transform, functional as F from torchvision.prototype.utils._internal import query_recursively -from torchvision.transforms.functional import pil_to_tensor, to_pil_image +from torchvision.transforms.autoaugment import AutoAugmentPolicy +from torchvision.transforms.functional import pil_to_tensor, to_pil_image, InterpolationMode from ._utils import get_image_dimensions, is_simple_tensor diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 5cd0c16ee19..0487a71416e 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -7,8 +7,8 @@ import PIL.Image import torch from torchvision.prototype import features -from torchvision.prototype.transforms import Transform, InterpolationMode, functional as F -from torchvision.transforms.functional import pil_to_tensor +from torchvision.prototype.transforms import Transform, functional as F +from torchvision.transforms.functional import pil_to_tensor, InterpolationMode from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int from typing_extensions import Literal diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 8718c381525..ecf0d31df3a 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -4,9 +4,8 @@ import PIL.Image import torch from torchvision.prototype import features -from torchvision.prototype.transforms import InterpolationMode from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP -from torchvision.transforms.functional import pil_modes_mapping, _get_inverse_affine_matrix +from torchvision.transforms.functional import pil_modes_mapping, _get_inverse_affine_matrix, InterpolationMode from ._meta import convert_bounding_box_format, get_dimensions_image_tensor, get_dimensions_image_pil diff --git a/torchvision/prototype/utils/_internal.py b/torchvision/prototype/utils/_internal.py index 147a7f0ff4c..fe5284394cb 100644 --- a/torchvision/prototype/utils/_internal.py +++ b/torchvision/prototype/utils/_internal.py @@ -1,14 +1,11 @@ import collections.abc import difflib -import functools -import inspect import io import mmap import os import os.path import platform import textwrap -import warnings from typing import ( Any, BinaryIO, @@ -36,7 +33,6 @@ "FrozenMapping", "make_repr", 
"FrozenBunch", - "kwonly_to_pos_or_kw", "fromfile", "ReadOnlyTensorBuffer", "apply_recursively", @@ -140,57 +136,6 @@ def __repr__(self) -> str: return make_repr(type(self).__name__, self.items()) -def kwonly_to_pos_or_kw(fn: Callable[..., D]) -> Callable[..., D]: - """Decorates a function that uses keyword only parameters to also allow them being passed as positionals. - - For example, consider the use case of changing the signature of ``old_fn`` into the one from ``new_fn``: - - .. code:: - - def old_fn(foo, bar, baz=None): - ... - - def new_fn(foo, *, bar, baz=None): - ... - - Calling ``old_fn("foo", "bar, "baz")`` was valid, but the same call is no longer valid with ``new_fn``. To keep BC - and at the same time warn the user of the deprecation, this decorator can be used: - - .. code:: - - @kwonly_to_pos_or_kw - def new_fn(foo, *, bar, baz=None): - ... - - new_fn("foo", "bar, "baz") - """ - params = inspect.signature(fn).parameters - - try: - keyword_only_start_idx = next( - idx for idx, param in enumerate(params.values()) if param.kind == param.KEYWORD_ONLY - ) - except StopIteration: - raise TypeError(f"Found no keyword-only parameter on function '{fn.__name__}'") from None - - keyword_only_params = tuple(inspect.signature(fn).parameters)[keyword_only_start_idx:] - - @functools.wraps(fn) - def wrapper(*args: Any, **kwargs: Any) -> D: - args, keyword_only_args = args[:keyword_only_start_idx], args[keyword_only_start_idx:] - if keyword_only_args: - keyword_only_kwargs = dict(zip(keyword_only_params, keyword_only_args)) - warnings.warn( - f"Using {sequence_to_str(tuple(keyword_only_kwargs.keys()), separate_last='and ')} as positional " - f"parameter(s) is deprecated. Please use keyword parameter(s) instead." - ) - kwargs.update(keyword_only_kwargs) - - return fn(*args, **kwargs) - - return wrapper - - def _read_mutable_buffer_fallback(file: BinaryIO, count: int, item_size: int) -> bytearray: # A plain file.read() will give a read-only bytes, so we convert it to bytearray to make it mutable return bytearray(file.read(-1 if count == -1 else count * item_size)) diff --git a/torchvision/prototype/transforms/_presets.py b/torchvision/transforms/_presets.py similarity index 53% rename from torchvision/prototype/transforms/_presets.py rename to torchvision/transforms/_presets.py index 3ab045b3ddb..4d503f44cc5 100644 --- a/torchvision/prototype/transforms/_presets.py +++ b/torchvision/transforms/_presets.py @@ -1,32 +1,35 @@ -from typing import Dict, Optional, Tuple +""" +This file is part of the private API. Please do not use directly these classes as they will be modified on +future versions without warning. The classes should be accessed only via the transforms argument of Weights. +""" +from typing import Optional, Tuple import torch from torch import Tensor, nn -from ...transforms import functional as F, InterpolationMode +from . 
import functional as F, InterpolationMode __all__ = [ - "ObjectDetectionEval", - "ImageClassificationEval", - "VideoClassificationEval", - "SemanticSegmentationEval", - "OpticalFlowEval", + "ObjectDetection", + "ImageClassification", + "VideoClassification", + "SemanticSegmentation", + "OpticalFlow", ] -class ObjectDetectionEval(nn.Module): - def forward( - self, img: Tensor, target: Optional[Dict[str, Tensor]] = None - ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: +class ObjectDetection(nn.Module): + def forward(self, img: Tensor) -> Tensor: if not isinstance(img, Tensor): img = F.pil_to_tensor(img) - return F.convert_image_dtype(img, torch.float), target + return F.convert_image_dtype(img, torch.float) -class ImageClassificationEval(nn.Module): +class ImageClassification(nn.Module): def __init__( self, + *, crop_size: int, resize_size: int = 256, mean: Tuple[float, ...] = (0.485, 0.456, 0.406), @@ -50,9 +53,10 @@ def forward(self, img: Tensor) -> Tensor: return img -class VideoClassificationEval(nn.Module): +class VideoClassification(nn.Module): def __init__( self, + *, crop_size: Tuple[int, int], resize_size: Tuple[int, int], mean: Tuple[float, ...] = (0.43216, 0.394666, 0.37645), @@ -67,53 +71,61 @@ def __init__( self._interpolation = interpolation def forward(self, vid: Tensor) -> Tensor: - vid = vid.permute(0, 3, 1, 2) # (T, H, W, C) => (T, C, H, W) + need_squeeze = False + if vid.ndim < 5: + vid = vid.unsqueeze(dim=0) + need_squeeze = True + + vid = vid.permute(0, 1, 4, 2, 3) # (N, T, H, W, C) => (N, T, C, H, W) + N, T, C, H, W = vid.shape + vid = vid.view(-1, C, H, W) vid = F.resize(vid, self._size, interpolation=self._interpolation) vid = F.center_crop(vid, self._crop_size) vid = F.convert_image_dtype(vid, torch.float) vid = F.normalize(vid, mean=self._mean, std=self._std) - return vid.permute(1, 0, 2, 3) # (T, C, H, W) => (C, T, H, W) + H, W = self._crop_size + vid = vid.view(N, T, C, H, W) + vid = vid.permute(0, 2, 1, 3, 4) # (N, T, C, H, W) => (N, C, T, H, W) + + if need_squeeze: + vid = vid.squeeze(dim=0) + return vid -class SemanticSegmentationEval(nn.Module): +class SemanticSegmentation(nn.Module): def __init__( self, - resize_size: int, + *, + resize_size: Optional[int], mean: Tuple[float, ...] = (0.485, 0.456, 0.406), std: Tuple[float, ...] 
= (0.229, 0.224, 0.225), interpolation: InterpolationMode = InterpolationMode.BILINEAR, - interpolation_target: InterpolationMode = InterpolationMode.NEAREST, ) -> None: super().__init__() - self._size = [resize_size] + self._size = [resize_size] if resize_size is not None else None self._mean = list(mean) self._std = list(std) self._interpolation = interpolation - self._interpolation_target = interpolation_target - def forward(self, img: Tensor, target: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]: - img = F.resize(img, self._size, interpolation=self._interpolation) + def forward(self, img: Tensor) -> Tensor: + if isinstance(self._size, list): + img = F.resize(img, self._size, interpolation=self._interpolation) if not isinstance(img, Tensor): img = F.pil_to_tensor(img) img = F.convert_image_dtype(img, torch.float) img = F.normalize(img, mean=self._mean, std=self._std) - if target: - target = F.resize(target, self._size, interpolation=self._interpolation_target) - if not isinstance(target, Tensor): - target = F.pil_to_tensor(target) - target = target.squeeze(0).to(torch.int64) - return img, target - + return img -class OpticalFlowEval(nn.Module): - def forward( - self, img1: Tensor, img2: Tensor, flow: Optional[Tensor], valid_flow_mask: Optional[Tensor] - ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: - img1, img2, flow, valid_flow_mask = self._pil_or_numpy_to_tensor(img1, img2, flow, valid_flow_mask) +class OpticalFlow(nn.Module): + def forward(self, img1: Tensor, img2: Tensor) -> Tuple[Tensor, Tensor]: + if not isinstance(img1, Tensor): + img1 = F.pil_to_tensor(img1) + if not isinstance(img2, Tensor): + img2 = F.pil_to_tensor(img2) - img1 = F.convert_image_dtype(img1, torch.float32) - img2 = F.convert_image_dtype(img2, torch.float32) + img1 = F.convert_image_dtype(img1, torch.float) + img2 = F.convert_image_dtype(img2, torch.float) # map [0, 1] into [-1, 1] img1 = F.normalize(img1, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) @@ -122,19 +134,4 @@ def forward( img1 = img1.contiguous() img2 = img2.contiguous() - return img1, img2, flow, valid_flow_mask - - def _pil_or_numpy_to_tensor( - self, img1: Tensor, img2: Tensor, flow: Optional[Tensor], valid_flow_mask: Optional[Tensor] - ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: - if not isinstance(img1, Tensor): - img1 = F.pil_to_tensor(img1) - if not isinstance(img2, Tensor): - img2 = F.pil_to_tensor(img2) - - if flow is not None and not isinstance(flow, Tensor): - flow = torch.from_numpy(flow) - if valid_flow_mask is not None and not isinstance(valid_flow_mask, Tensor): - valid_flow_mask = torch.from_numpy(valid_flow_mask) - - return img1, img2, flow, valid_flow_mask + return img1, img2 From f7d0a50cd59e925b2d260cab485857c1d81db933 Mon Sep 17 00:00:00 2001 From: Lezwon Castelino Date: Wed, 23 Mar 2022 14:07:30 +0530 Subject: [PATCH 09/20] Apply suggestions from code review Co-authored-by: Philip Meier --- torchvision/prototype/datasets/_builtin/usps.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/prototype/datasets/_builtin/usps.py b/torchvision/prototype/datasets/_builtin/usps.py index 4e302036bfa..d5da59b2346 100644 --- a/torchvision/prototype/datasets/_builtin/usps.py +++ b/torchvision/prototype/datasets/_builtin/usps.py @@ -53,7 +53,7 @@ def _make_info(self) -> DatasetInfo: def resources(self, config: DatasetConfig) -> List[OnlineResource]: return [USPS._RESOURCES[config.split]] - def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor], *, 
config: DatasetConfig) -> Dict[str, Any]: + def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor]) -> Dict[str, Any]: image, label = data return dict( image=Image(image), @@ -69,4 +69,4 @@ def _make_datapipe( dp = USPSFileReader(resource_dps[0]) dp = hint_sharding(dp) dp = hint_shuffling(dp) - return Mapper(dp, functools.partial(self._prepare_sample, config=config)) + return Mapper(dp, self._prepare_sample) From 85bc5fd9c57cd393a099f8ecc0cadd382e2a7f3c Mon Sep 17 00:00:00 2001 From: Lezwon Castelino Date: Wed, 23 Mar 2022 16:38:28 +0530 Subject: [PATCH 10/20] use decompressor for extracting bz2 --- .../prototype/datasets/_builtin/usps.py | 41 +++++++------------ 1 file changed, 15 insertions(+), 26 deletions(-) diff --git a/torchvision/prototype/datasets/_builtin/usps.py b/torchvision/prototype/datasets/_builtin/usps.py index d5da59b2346..b4d6c055639 100644 --- a/torchvision/prototype/datasets/_builtin/usps.py +++ b/torchvision/prototype/datasets/_builtin/usps.py @@ -1,33 +1,13 @@ -import bz2 -import functools -from typing import Any, Dict, List, Tuple, BinaryIO, Iterator +from typing import Any, Dict, List, Tuple import numpy as np import torch -from torchdata.datapipes.iter import IterDataPipe, IterableWrapper, LineReader, Mapper +from torchdata.datapipes.iter import IterDataPipe, LineReader, Mapper, Decompressor from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, DatasetConfig, OnlineResource, HttpResource from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling from torchvision.prototype.features import Image, Label -class USPSFileReader(IterDataPipe[torch.Tensor]): - def __init__(self, datapipe: IterDataPipe[Tuple[Any, BinaryIO]]) -> None: - self.datapipe = datapipe - - def __iter__(self) -> Iterator[Tuple[torch.Tensor, torch.Tensor]]: - for path, _ in self.datapipe: - with bz2.open(path) as fp: - datapipe = IterableWrapper([(path, fp)]) - line_reader = LineReader(datapipe, decode=True) - for _, line in line_reader: - raw_data = line.split() - tmp_list = [x.split(":")[-1] for x in raw_data[1:]] - img = np.asarray(tmp_list, dtype=np.float32).reshape((-1, 16, 16)) - img = ((img + 1) / 2 * 255).astype(dtype=np.uint8) - target = int(raw_data[0]) - 1 - yield torch.from_numpy(img), torch.tensor(target) - - class USPS(Dataset): def _make_info(self) -> DatasetInfo: return DatasetInfo( @@ -54,10 +34,18 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]: return [USPS._RESOURCES[config.split]] def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor]) -> Dict[str, Any]: - image, label = data + _filename, line = data + + raw_data = line.split() + tmp_list = [x.split(":")[-1] for x in raw_data[1:]] + img = np.asarray(tmp_list, dtype=np.float32).reshape((-1, 16, 16)) + img = ((img + 1) / 2 * 255).astype(dtype=np.uint8) + img = torch.from_numpy(img) + target = int(raw_data[0]) - 1 + return dict( - image=Image(image), - label=Label(label, dtype=torch.int64, categories=self.categories), + image=Image(img), + label=Label(target, dtype=torch.int64, categories=self.categories), ) def _make_datapipe( @@ -66,7 +54,8 @@ def _make_datapipe( *, config: DatasetConfig, ) -> IterDataPipe[Dict[str, Any]]: - dp = USPSFileReader(resource_dps[0]) + dp = Decompressor(resource_dps[0]) + dp = LineReader(dp, decode=True) dp = hint_sharding(dp) dp = hint_shuffling(dp) return Mapper(dp, self._prepare_sample) From 68f15ba23cc6c44982836816499c4a52674446b4 Mon Sep 17 00:00:00 2001 From: Lezwon Castelino Date: Wed, 23 Mar 2022 22:16:24 
+0530 Subject: [PATCH 11/20] Apply suggestions from code review Co-authored-by: Philip Meier --- torchvision/prototype/datasets/_builtin/usps.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/torchvision/prototype/datasets/_builtin/usps.py b/torchvision/prototype/datasets/_builtin/usps.py index b4d6c055639..5f8e552ec5d 100644 --- a/torchvision/prototype/datasets/_builtin/usps.py +++ b/torchvision/prototype/datasets/_builtin/usps.py @@ -33,8 +33,7 @@ def _make_info(self) -> DatasetInfo: def resources(self, config: DatasetConfig) -> List[OnlineResource]: return [USPS._RESOURCES[config.split]] - def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor]) -> Dict[str, Any]: - _filename, line = data + def _prepare_sample(self, line: str) -> Dict[str, Any]: raw_data = line.split() tmp_list = [x.split(":")[-1] for x in raw_data[1:]] @@ -55,7 +54,7 @@ def _make_datapipe( config: DatasetConfig, ) -> IterDataPipe[Dict[str, Any]]: dp = Decompressor(resource_dps[0]) - dp = LineReader(dp, decode=True) + dp = LineReader(dp, decode=True, return_path=False) dp = hint_sharding(dp) dp = hint_shuffling(dp) return Mapper(dp, self._prepare_sample) From aefad05cc34648cef7de23c913ce3f804110d8a0 Mon Sep 17 00:00:00 2001 From: Lezwon Castelino Date: Wed, 23 Mar 2022 22:33:39 +0530 Subject: [PATCH 12/20] Apply suggestions from code review Co-authored-by: Philip Meier --- torchvision/prototype/datasets/_builtin/usps.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/torchvision/prototype/datasets/_builtin/usps.py b/torchvision/prototype/datasets/_builtin/usps.py index 5f8e552ec5d..f63672f43a9 100644 --- a/torchvision/prototype/datasets/_builtin/usps.py +++ b/torchvision/prototype/datasets/_builtin/usps.py @@ -34,17 +34,12 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]: return [USPS._RESOURCES[config.split]] def _prepare_sample(self, line: str) -> Dict[str, Any]: - - raw_data = line.split() - tmp_list = [x.split(":")[-1] for x in raw_data[1:]] - img = np.asarray(tmp_list, dtype=np.float32).reshape((-1, 16, 16)) - img = ((img + 1) / 2 * 255).astype(dtype=np.uint8) - img = torch.from_numpy(img) - target = int(raw_data[0]) - 1 - + label, *values = line.strip().split(" ") + values = [float(value.split(":")[1]) for value in values] + pixels = torch.tensor(values).add_(1).div_(2) return dict( - image=Image(img), - label=Label(target, dtype=torch.int64, categories=self.categories), + image=Image(pixels.reshape(16, 16)), + label=Label(int(label) - 1, categories=self.categories), ) def _make_datapipe( From 81cfba99aa2efeb1e1ff75a8435b73cfb56c5877 Mon Sep 17 00:00:00 2001 From: Lezwon Castelino Date: Thu, 24 Mar 2022 11:31:17 +0530 Subject: [PATCH 13/20] fixed lint fails --- torchvision/prototype/datasets/_builtin/usps.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchvision/prototype/datasets/_builtin/usps.py b/torchvision/prototype/datasets/_builtin/usps.py index f63672f43a9..5df0978d031 100644 --- a/torchvision/prototype/datasets/_builtin/usps.py +++ b/torchvision/prototype/datasets/_builtin/usps.py @@ -1,6 +1,5 @@ -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List -import numpy as np import torch from torchdata.datapipes.iter import IterDataPipe, LineReader, Mapper, Decompressor from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, DatasetConfig, OnlineResource, HttpResource From e847fde7401bae1a6dd6877a0700f85d651e1fd0 Mon Sep 17 00:00:00 2001 From: Lezwon 
Castelino Date: Thu, 24 Mar 2022 13:24:49 +0530 Subject: [PATCH 14/20] added tests for USPS --- test/builtin_dataset_mocks.py | 77 +++++++++++++++++++++++++ test/test_prototype_builtin_datasets.py | 23 +++++++- 2 files changed, 99 insertions(+), 1 deletion(-) diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index 62259a604a0..db98d9399df 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -1,3 +1,4 @@ +import bz2 import collections.abc import csv import functools @@ -1431,3 +1432,79 @@ def stanford_cars(info, root, config): make_tar(root, "car_devkit.tgz", devkit, compression="gz") return num_samples + + +class USPSMockData: + @classmethod + def generate_images(cls, num_samples, shape, image_dtype, low, high): + return make_tensor((num_samples, *shape), dtype=image_dtype, low=low, high=high) + + @classmethod + def generate_labels(cls, num_samples, label_dtype, low, high): + return make_tensor((num_samples,), dtype=label_dtype, low=low, high=high) + + @classmethod + def generate( + cls, + root, + *, + num_categories, + num_samples=None, + data_file, + image_size=(16 * 16,), + image_dtype=torch.half, + label_dtype=torch.uint8, + compressor=None, + ): + if num_samples is None: + num_samples = len(num_categories) + if compressor is None: + compressor = bz2.open + + image_data = cls.generate_images( + num_samples=num_samples, + shape=image_size, + image_dtype=image_dtype, + low=-1, + high=1, + ) + + label_data = cls.generate_labels( + num_samples=num_samples, + label_dtype=label_dtype, + low=0, + high=len(num_categories), + ) + + cls._create_binary_file( + root, + data_file, + image_data=image_data, + label_data=label_data, + compressor=compressor, + ) + + return num_samples + + @classmethod + def _create_binary_file(cls, root, data_file, image_data, label_data, compressor): + with compressor(root / data_file, "wb") as f: + for image, label in zip(image_data, label_data): + encoded_label = str(label.item()) + encoded_image = cls.encode_image(image) + encoded_bytes = bytes(f"{encoded_label} {encoded_image} \n", encoding="utf-8") + f.write(encoded_bytes) + + @classmethod + def encode_image(cls, image): + data = [f"{i}:{value.item()}" for i, value in enumerate(image, start=1)] + return " ".join(data) + + +@register_mock +def USPS(info, root, config): + num_samples = {"train": 15, "test": 7}[config["split"]] + + train = config.split == "train" + data_file = f"usps.{'t.' 
if not train else ''}bz2" + return USPSMockData.generate(root, num_categories=info.categories, num_samples=num_samples, data_file=data_file) diff --git a/test/test_prototype_builtin_datasets.py b/test/test_prototype_builtin_datasets.py index f7c40d432a4..8270774e0b0 100644 --- a/test/test_prototype_builtin_datasets.py +++ b/test/test_prototype_builtin_datasets.py @@ -12,7 +12,7 @@ from torchdata.datapipes.iter import IterDataPipe, Shuffler from torchvision._utils import sequence_to_str from torchvision.prototype import transforms, datasets - +from torchvision.prototype.features import Image, Label assert_samples_equal = functools.partial( assert_equal, pair_types=(TensorLikePair, ObjectPair), rtol=0, atol=0, equal_nan=True @@ -180,3 +180,24 @@ def test_label_matches_path(self, test_home, dataset_mock, config): for sample in dataset: label_from_path = int(Path(sample["path"]).parent.name) assert sample["label"] == label_from_path + + +@parametrize_dataset_mocks(DATASET_MOCKS["USPS"]) +class TestUSPS: + def test_label_matches_path(self, test_home, dataset_mock, config): + dataset_mock.prepare(test_home, config) + + dataset = datasets.load(dataset_mock.name, **config) + + for sample in dataset: + # check if correct keys exist + assert "image" in sample + assert "label" in sample + + # check if correct instance type + assert isinstance(sample["image"], Image) + assert isinstance(sample["label"], Label) + + # check is data type is correct + assert isinstance(sample["image"].data, torch.FloatTensor) + assert isinstance(sample["label"].data, torch.LongTensor) From fe7c573f3516da3a2bc83658fb6a6c4b8a39a68c Mon Sep 17 00:00:00 2001 From: Lezwon Castelino Date: Thu, 24 Mar 2022 13:27:02 +0530 Subject: [PATCH 15/20] check image shape --- test/test_prototype_builtin_datasets.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/test_prototype_builtin_datasets.py b/test/test_prototype_builtin_datasets.py index 8270774e0b0..1dad70774da 100644 --- a/test/test_prototype_builtin_datasets.py +++ b/test/test_prototype_builtin_datasets.py @@ -201,3 +201,6 @@ def test_label_matches_path(self, test_home, dataset_mock, config): # check is data type is correct assert isinstance(sample["image"].data, torch.FloatTensor) assert isinstance(sample["label"].data, torch.LongTensor) + + # verify image size is (1, 16, 16 + assert sample["image"].data.shape == (1, 16, 16) From b1deb6347207ab695c6b092687a8b74337fc9bcc Mon Sep 17 00:00:00 2001 From: Lezwon Castelino Date: Thu, 24 Mar 2022 13:38:45 +0530 Subject: [PATCH 16/20] fix tests --- test/builtin_dataset_mocks.py | 2 +- test/test_prototype_builtin_datasets.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index db98d9399df..85e81526d44 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -1502,7 +1502,7 @@ def encode_image(cls, image): @register_mock -def USPS(info, root, config): +def usps(info, root, config): num_samples = {"train": 15, "test": 7}[config["split"]] train = config.split == "train" diff --git a/test/test_prototype_builtin_datasets.py b/test/test_prototype_builtin_datasets.py index 1dad70774da..1ec049b045b 100644 --- a/test/test_prototype_builtin_datasets.py +++ b/test/test_prototype_builtin_datasets.py @@ -182,7 +182,7 @@ def test_label_matches_path(self, test_home, dataset_mock, config): assert sample["label"] == label_from_path -@parametrize_dataset_mocks(DATASET_MOCKS["USPS"]) +@parametrize_dataset_mocks(DATASET_MOCKS["usps"]) class 
TestUSPS: def test_label_matches_path(self, test_home, dataset_mock, config): dataset_mock.prepare(test_home, config) From 0624de9c977474156ada7481871b8daff1d35bda Mon Sep 17 00:00:00 2001 From: Lezwon Castelino Date: Thu, 24 Mar 2022 14:45:43 +0530 Subject: [PATCH 17/20] check shape on image directly --- test/test_prototype_builtin_datasets.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_prototype_builtin_datasets.py b/test/test_prototype_builtin_datasets.py index 1ec049b045b..06c182575e5 100644 --- a/test/test_prototype_builtin_datasets.py +++ b/test/test_prototype_builtin_datasets.py @@ -202,5 +202,5 @@ def test_label_matches_path(self, test_home, dataset_mock, config): assert isinstance(sample["image"].data, torch.FloatTensor) assert isinstance(sample["label"].data, torch.LongTensor) - # verify image size is (1, 16, 16 - assert sample["image"].data.shape == (1, 16, 16) + # verify image size is (1, 16, 16) + assert sample["image"].shape == (1, 16, 16) From d5d9386978200c68256d103721f932abe5256ee5 Mon Sep 17 00:00:00 2001 From: Lezwon Castelino Date: Fri, 25 Mar 2022 15:06:35 +0530 Subject: [PATCH 18/20] Apply suggestions from code review Co-authored-by: Philip Meier --- test/builtin_dataset_mocks.py | 84 ++++++----------------------------- 1 file changed, 13 insertions(+), 71 deletions(-) diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index 85e81526d44..1153c1b33f0 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -1434,77 +1434,19 @@ def stanford_cars(info, root, config): return num_samples -class USPSMockData: - @classmethod - def generate_images(cls, num_samples, shape, image_dtype, low, high): - return make_tensor((num_samples, *shape), dtype=image_dtype, low=low, high=high) - - @classmethod - def generate_labels(cls, num_samples, label_dtype, low, high): - return make_tensor((num_samples,), dtype=label_dtype, low=low, high=high) - - @classmethod - def generate( - cls, - root, - *, - num_categories, - num_samples=None, - data_file, - image_size=(16 * 16,), - image_dtype=torch.half, - label_dtype=torch.uint8, - compressor=None, - ): - if num_samples is None: - num_samples = len(num_categories) - if compressor is None: - compressor = bz2.open - - image_data = cls.generate_images( - num_samples=num_samples, - shape=image_size, - image_dtype=image_dtype, - low=-1, - high=1, - ) - - label_data = cls.generate_labels( - num_samples=num_samples, - label_dtype=label_dtype, - low=0, - high=len(num_categories), - ) - - cls._create_binary_file( - root, - data_file, - image_data=image_data, - label_data=label_data, - compressor=compressor, - ) - - return num_samples - - @classmethod - def _create_binary_file(cls, root, data_file, image_data, label_data, compressor): - with compressor(root / data_file, "wb") as f: - for image, label in zip(image_data, label_data): - encoded_label = str(label.item()) - encoded_image = cls.encode_image(image) - encoded_bytes = bytes(f"{encoded_label} {encoded_image} \n", encoding="utf-8") - f.write(encoded_bytes) - - @classmethod - def encode_image(cls, image): - data = [f"{i}:{value.item()}" for i, value in enumerate(image, start=1)] - return " ".join(data) - - @register_mock def usps(info, root, config): - num_samples = {"train": 15, "test": 7}[config["split"]] + num_samples = {"train": 15, "test": 7}[config.split] - train = config.split == "train" - data_file = f"usps.{'t.' 
if not train else ''}bz2" - return USPSMockData.generate(root, num_categories=info.categories, num_samples=num_samples, data_file=data_file) + with bz2.open(root / f"usps{'.t' if not config.split == 'train' else ''}.bz2", "wb") as fh: + lines = [] + for _ in range(num_samples): + label = make_tensor(1, low=1, high=11, dtype=torch.int) + values = make_tensor(256, low=-1, high=1, dtype=torch.float) + lines.append( + " ".join([f"{int(label)}", *(f"{idx}:{float(value):.6f}" for idx, value in enumerate(values, 1))]) + ) + + fh.write("\n".join(lines).encode()) + + return num_samples From 73248616aa165679b20e52f29b435e89e158a61f Mon Sep 17 00:00:00 2001 From: Lezwon Castelino Date: Fri, 25 Mar 2022 14:52:25 +0530 Subject: [PATCH 19/20] removed test and comments --- test/test_prototype_builtin_datasets.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/test/test_prototype_builtin_datasets.py b/test/test_prototype_builtin_datasets.py index 06c182575e5..92f147b8825 100644 --- a/test/test_prototype_builtin_datasets.py +++ b/test/test_prototype_builtin_datasets.py @@ -190,17 +190,10 @@ def test_label_matches_path(self, test_home, dataset_mock, config): dataset = datasets.load(dataset_mock.name, **config) for sample in dataset: - # check if correct keys exist assert "image" in sample assert "label" in sample - # check if correct instance type assert isinstance(sample["image"], Image) assert isinstance(sample["label"], Label) - # check is data type is correct - assert isinstance(sample["image"].data, torch.FloatTensor) - assert isinstance(sample["label"].data, torch.LongTensor) - - # verify image size is (1, 16, 16) assert sample["image"].shape == (1, 16, 16) From d23863cccc1f6482ba2624d2db521e7a2faf52af Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 25 Mar 2022 11:56:10 +0100 Subject: [PATCH 20/20] Update test/test_prototype_builtin_datasets.py Co-authored-by: Nicolas Hug --- test/test_prototype_builtin_datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_prototype_builtin_datasets.py b/test/test_prototype_builtin_datasets.py index 92f147b8825..f414f4e48cd 100644 --- a/test/test_prototype_builtin_datasets.py +++ b/test/test_prototype_builtin_datasets.py @@ -184,7 +184,7 @@ def test_label_matches_path(self, test_home, dataset_mock, config): @parametrize_dataset_mocks(DATASET_MOCKS["usps"]) class TestUSPS: - def test_label_matches_path(self, test_home, dataset_mock, config): + def test_sample_content(self, test_home, dataset_mock, config): dataset_mock.prepare(test_home, config) dataset = datasets.load(dataset_mock.name, **config)
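Taken together, the patches above register a prototype USPS dataset and exercise it through the prototype datasets API in TestUSPS. The snippet below is a minimal usage sketch and is not part of any patch in this series; it assumes the final state of the series (a dataset registered under the name "usps" with a "split" option of "train" or "test", and samples carrying an Image and a Label as asserted in TestUSPS.test_sample_content).

    from torchvision.prototype import datasets

    # Load the prototype USPS dataset; "split" accepts "train" or "test",
    # matching the valid_options declared in the dataset's DatasetInfo.
    dataset = datasets.load("usps", split="train")

    # Each sample is a dict with a 1x16x16 Image and a Label in [0, 9],
    # mirroring the assertions in TestUSPS above.
    sample = next(iter(dataset))
    assert sample["image"].shape == (1, 16, 16)
    assert 0 <= int(sample["label"]) <= 9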