Commit 6d96ed5

Authored by datumbox and Nicolas Hug
Porting docs, examples, tutorials and galleries (#5620)
* Fix examples, tutorials and gallery
* Update gallery/plot_optical_flow.py
* Fix import
* Revert hardcoded normalization
* fix uncommitted changes
* Fix bug
* Fix more bugs
* Making resize optional for segmentation
* Fixing preset
* Fix mypy
* Fixing documentation strings
* Fix flake8
* minor refactoring

Co-authored-by: Nicolas Hug <[email protected]>
1 parent 5a96c9a commit 6d96ed5

20 files changed: +115, -81 lines changed

android/test_app/make_assets.py

Lines changed: 10 additions & 3 deletions
@@ -1,11 +1,18 @@
 import torch
-import torchvision
 from torch.utils.mobile_optimizer import optimize_for_mobile
+from torchvision.models.detection import (
+    fasterrcnn_mobilenet_v3_large_320_fpn,
+    FasterRCNN_MobileNet_V3_Large_320_FPN_Weights,
+)

 print(torch.__version__)

-model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn(
-    pretrained=True, box_score_thresh=0.7, rpn_post_nms_top_n_test=100, rpn_score_thresh=0.4, rpn_pre_nms_top_n_test=150
+model = fasterrcnn_mobilenet_v3_large_320_fpn(
+    weights=FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.DEFAULT,
+    box_score_thresh=0.7,
+    rpn_post_nms_top_n_test=100,
+    rpn_score_thresh=0.4,
+    rpn_pre_nms_top_n_test=150,
 )

 model.eval()
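Under the new multi-weight API, `DEFAULT` resolves to the best available checkpoint, so the ported call keeps the old `pretrained=True` behavior. A minimal sketch of exercising the ported model, assuming a torchvision build that ships these enums:

import torch
from torchvision.models.detection import (
    fasterrcnn_mobilenet_v3_large_320_fpn,
    FasterRCNN_MobileNet_V3_Large_320_FPN_Weights,
)

# DEFAULT is an alias for the best available weights (here the COCO checkpoint).
weights = FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.DEFAULT
model = fasterrcnn_mobilenet_v3_large_320_fpn(weights=weights, box_score_thresh=0.7)
model.eval()

# Dummy forward pass to confirm the ported model still loads and runs.
with torch.no_grad():
    predictions = model([torch.rand(3, 320, 320)])
print(predictions[0].keys())  # dict_keys(['boxes', 'labels', 'scores'])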

examples/cpp/hello_world/trace_model.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 HERE = osp.dirname(osp.abspath(__file__))
 ASSETS = osp.dirname(osp.dirname(HERE))

-model = torchvision.models.resnet18(pretrained=False)
+model = torchvision.models.resnet18()
 model.eval()

 traced_model = torch.jit.script(model)
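Dropping `pretrained=False` works because random initialization is the default under the new API; the two calls below are equivalent:

import torchvision

# No weights argument and an explicit weights=None both mean "no pretrained checkpoint".
m1 = torchvision.models.resnet18()
m2 = torchvision.models.resnet18(weights=None)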

gallery/plot_optical_flow.py

Lines changed: 15 additions & 19 deletions
@@ -19,7 +19,6 @@
 import torch
 import matplotlib.pyplot as plt
 import torchvision.transforms.functional as F
-import torchvision.transforms as T


 plt.rcParams["savefig.bbox"] = "tight"
@@ -88,24 +87,19 @@ def plot(imgs, **imshow_kwargs):
 # reduce the image sizes for the example to run faster. Image dimension must be
 # divisible by 8.

+from torchvision.models.optical_flow import Raft_Large_Weights

-def preprocess(batch):
-    transforms = T.Compose(
-        [
-            T.ConvertImageDtype(torch.float32),
-            T.Normalize(mean=0.5, std=0.5),  # map [0, 1] into [-1, 1]
-            T.Resize(size=(520, 960)),
-        ]
-    )
-    batch = transforms(batch)
-    return batch
+weights = Raft_Large_Weights.DEFAULT
+transforms = weights.transforms()


-# If you can, run this example on a GPU, it will be a lot faster.
-device = "cuda" if torch.cuda.is_available() else "cpu"
+def preprocess(img1_batch, img2_batch):
+    img1_batch = F.resize(img1_batch, size=[520, 960])
+    img2_batch = F.resize(img2_batch, size=[520, 960])
+    return transforms(img1_batch, img2_batch)[:2]
+

-img1_batch = preprocess(img1_batch).to(device)
-img2_batch = preprocess(img2_batch).to(device)
+img1_batch, img2_batch = preprocess(img1_batch, img2_batch)

 print(f"shape = {img1_batch.shape}, dtype = {img1_batch.dtype}")
@@ -121,7 +115,10 @@ def preprocess(batch):

 from torchvision.models.optical_flow import raft_large

-model = raft_large(pretrained=True, progress=False).to(device)
+# If you can, run this example on a GPU, it will be a lot faster.
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+model = raft_large(weights=Raft_Large_Weights.DEFAULT, progress=False).to(device)
 model = model.eval()

 list_of_flows = model(img1_batch.to(device), img2_batch.to(device))
@@ -182,10 +179,9 @@ def preprocess(batch):
 # from torchvision.io import write_jpeg
 # for i, (img1, img2) in enumerate(zip(frames, frames[1:])):
 #     # Note: it would be faster to predict batches of flows instead of individual flows
-#     img1 = preprocess(img1[None]).to(device)
-#     img2 = preprocess(img2[None]).to(device)
+#     img1, img2 = preprocess(img1, img2)

-#     list_of_flows = model(img1_batch, img2_batch)
+#     list_of_flows = model(img1.to(device), img2.to(device))
 #     predicted_flow = list_of_flows[-1][0]
 #     flow_img = flow_to_image(predicted_flow).to("cpu")
 #     output_folder = "/tmp/"  # Update this to the folder of your choice
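The key change above is that the hand-rolled normalization is replaced by the preset bundled with the RAFT weights; only the resize stays explicit because image dimensions must remain divisible by 8. A hedged sketch of the ported preprocessing on dummy data (the `[:2]` slice follows this diff and assumes the preset may return more than the two batches):

import torch
from torchvision.models.optical_flow import raft_large, Raft_Large_Weights

weights = Raft_Large_Weights.DEFAULT
transforms = weights.transforms()

# Dummy uint8 frame batches with dimensions divisible by 8.
img1_batch = torch.randint(0, 256, (2, 3, 520, 960), dtype=torch.uint8)
img2_batch = torch.randint(0, 256, (2, 3, 520, 960), dtype=torch.uint8)

# The preset converts dtype and maps [0, 1] to [-1, 1], as the removed code did.
img1_batch, img2_batch = transforms(img1_batch, img2_batch)[:2]

device = "cuda" if torch.cuda.is_available() else "cpu"
model = raft_large(weights=weights, progress=False).to(device).eval()
list_of_flows = model(img1_batch.to(device), img2_batch.to(device))
predicted_flows = list_of_flows[-1]  # the last iteration is the most accurate estimate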

gallery/plot_repurposing_annotations.py

Lines changed: 5 additions & 3 deletions
@@ -139,12 +139,14 @@ def show(imgs):
 # Here is demo with a Faster R-CNN model loaded from
 # :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn`

-from torchvision.models.detection import fasterrcnn_resnet50_fpn
+from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_Weights

-model = fasterrcnn_resnet50_fpn(pretrained=True, progress=False)
+weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
+model = fasterrcnn_resnet50_fpn(weights=weights, progress=False)
 print(img.size())

-img = F.convert_image_dtype(img, torch.float)
+transforms = weights.transforms()
+img, _ = transforms(img)
 target = {}
 target["boxes"] = boxes
 target["labels"] = labels = torch.ones((masks.size(0),), dtype=torch.int64)
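One subtlety worth noting: the detection preset appears to operate on (image, target) pairs, hence the tuple unpack replacing the plain dtype conversion. A short sketch under that assumption:

import torch
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights

transforms = FasterRCNN_ResNet50_FPN_Weights.DEFAULT.transforms()

img = torch.randint(0, 256, (3, 400, 600), dtype=torch.uint8)
img, target = transforms(img)  # target is None when no annotations are passed
print(img.dtype, float(img.min()), float(img.max()))  # float32 values in [0, 1]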

gallery/plot_scripted_tensor_transforms.py

Lines changed: 4 additions & 8 deletions
@@ -85,20 +85,16 @@ def show(imgs):
 # Let's define a ``Predictor`` module that transforms the input tensor and then
 # applies an ImageNet model on it.

-from torchvision.models import resnet18
+from torchvision.models import resnet18, ResNet18_Weights


 class Predictor(nn.Module):

     def __init__(self):
         super().__init__()
-        self.resnet18 = resnet18(pretrained=True, progress=False).eval()
-        self.transforms = nn.Sequential(
-            T.Resize([256, ]),  # We use single int value inside a list due to torchscript type restrictions
-            T.CenterCrop(224),
-            T.ConvertImageDtype(torch.float),
-            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
-        )
+        weights = ResNet18_Weights.DEFAULT
+        self.resnet18 = resnet18(weights=weights, progress=False).eval()
+        self.transforms = weights.transforms()

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         with torch.no_grad():
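This gallery later runs `torch.jit.script` on the `Predictor`, so the bundled preset has to stay TorchScript-compatible to replace the old `nn.Sequential` pipeline wholesale. A self-contained sketch of the same pattern (hedged: it assumes the preset is scriptable, which the gallery relies on):

import torch
from torch import nn
from torchvision.models import resnet18, ResNet18_Weights


class Predictor(nn.Module):
    def __init__(self):
        super().__init__()
        weights = ResNet18_Weights.DEFAULT
        self.resnet18 = resnet18(weights=weights, progress=False).eval()
        self.transforms = weights.transforms()  # bundled resize/crop/normalize preset

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            x = self.transforms(x)
            return self.resnet18(x).argmax(dim=1)


scripted = torch.jit.script(Predictor())  # only works if the preset is scriptable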

gallery/plot_visualization_utils.py

Lines changed: 27 additions & 12 deletions
@@ -73,14 +73,17 @@ def show(imgs):
 # :func:`~torchvision.models.detection.ssd300_vgg16`. For more details
 # on the output of such models, you may refer to :ref:`instance_seg_output`.

-from torchvision.models.detection import fasterrcnn_resnet50_fpn
-from torchvision.transforms.functional import convert_image_dtype
+from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_Weights


 batch_int = torch.stack([dog1_int, dog2_int])
-batch = convert_image_dtype(batch_int, dtype=torch.float)

-model = fasterrcnn_resnet50_fpn(pretrained=True, progress=False)
+weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
+transforms = weights.transforms()
+
+batch, _ = transforms(batch_int)
+
+model = fasterrcnn_resnet50_fpn(weights=weights, progress=False)
 model = model.eval()

 outputs = model(batch)
@@ -120,13 +123,15 @@ def show(imgs):
 # images must be normalized before they're passed to a semantic segmentation
 # model.

-from torchvision.models.segmentation import fcn_resnet50
+from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights

+weights = FCN_ResNet50_Weights.DEFAULT
+transforms = weights.transforms(resize_size=None)

-model = fcn_resnet50(pretrained=True, progress=False)
+model = fcn_resnet50(weights=weights, progress=False)
 model = model.eval()

-normalized_batch = F.normalize(batch, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
+normalized_batch, _ = transforms(batch)
 output = model(normalized_batch)['out']
 print(output.shape, output.min().item(), output.max().item())
@@ -262,8 +267,14 @@ def show(imgs):
 # of them may not have masks, like
 # :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn`.

-from torchvision.models.detection import maskrcnn_resnet50_fpn
-model = maskrcnn_resnet50_fpn(pretrained=True, progress=False)
+from torchvision.models.detection import maskrcnn_resnet50_fpn, MaskRCNN_ResNet50_FPN_Weights
+
+weights = MaskRCNN_ResNet50_FPN_Weights.DEFAULT
+transforms = weights.transforms()
+
+batch, _ = transforms(batch_int)
+
+model = maskrcnn_resnet50_fpn(weights=weights, progress=False)
 model = model.eval()

 output = model(batch)
@@ -378,13 +389,17 @@ def show(imgs):
 # Note that the keypoint detection model does not need normalized images.
 #

-from torchvision.models.detection import keypointrcnn_resnet50_fpn
+from torchvision.models.detection import keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights
 from torchvision.io import read_image

 person_int = read_image(str(Path("assets") / "person1.jpg"))
-person_float = convert_image_dtype(person_int, dtype=torch.float)

-model = keypointrcnn_resnet50_fpn(pretrained=True, progress=False)
+weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
+transforms = weights.transforms()
+
+person_float, _ = transforms(person_int)
+
+model = keypointrcnn_resnet50_fpn(weights=weights, progress=False)
 model = model.eval()

 outputs = model([person_float])
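Besides `transforms()`, the weights enums carry metadata that visualization code can use for label names. A hedged sketch, assuming a `meta` dict with a "categories" list per the multi-weight API design:

import torch
from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_Weights

weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
model = fasterrcnn_resnet50_fpn(weights=weights, progress=False).eval()

batch, _ = weights.transforms()(torch.randint(0, 256, (1, 3, 400, 600), dtype=torch.uint8))
output = model(batch)[0]

categories = weights.meta["categories"]  # assumed: the COCO category names
names = [categories[i] for i in output["labels"]]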

ios/VisionTestApp/make_assets.py

Lines changed: 10 additions & 3 deletions
@@ -1,11 +1,18 @@
 import torch
-import torchvision
 from torch.utils.mobile_optimizer import optimize_for_mobile
+from torchvision.models.detection import (
+    fasterrcnn_mobilenet_v3_large_320_fpn,
+    FasterRCNN_MobileNet_V3_Large_320_FPN_Weights,
+)

 print(torch.__version__)

-model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn(
-    pretrained=True, box_score_thresh=0.7, rpn_post_nms_top_n_test=100, rpn_score_thresh=0.4, rpn_pre_nms_top_n_test=150
+model = fasterrcnn_mobilenet_v3_large_320_fpn(
+    weights=FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.DEFAULT,
+    box_score_thresh=0.7,
+    rpn_post_nms_top_n_test=100,
+    rpn_score_thresh=0.4,
+    rpn_pre_nms_top_n_test=150,
 )

 model.eval()

test/tracing/frcnn/trace_model.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 HERE = osp.dirname(osp.abspath(__file__))
 ASSETS = osp.dirname(osp.dirname(HERE))

-model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False)
+model = torchvision.models.detection.fasterrcnn_resnet50_fpn()
 model.eval()

 traced_model = torch.jit.script(model)

torchvision/models/_utils.py

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ class IntermediateLayerGetter(nn.ModuleDict):

     Examples::

-        >>> m = torchvision.models.resnet18(pretrained=True)
+        >>> m = torchvision.models.resnet18(weights=ResNet18_Weights.DEFAULT)
         >>> # extract layer1 and layer3, giving as names `feat1` and feat2`
         >>> new_m = torchvision.models._utils.IntermediateLayerGetter(m,
         >>>     {'layer1': 'feat1', 'layer3': 'feat2'})
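Running the updated docstring example also requires importing the enum. An expanded, runnable version:

import torch
import torchvision
from torchvision.models import ResNet18_Weights

m = torchvision.models.resnet18(weights=ResNet18_Weights.DEFAULT)
# Extract layer1 and layer3, giving them the names feat1 and feat2.
new_m = torchvision.models._utils.IntermediateLayerGetter(
    m, {"layer1": "feat1", "layer3": "feat2"}
)
out = new_m(torch.rand(1, 3, 224, 224))
print([(k, v.shape) for k, v in out.items()])
# [('feat1', torch.Size([1, 64, 56, 56])), ('feat2', torch.Size([1, 256, 14, 14]))]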

torchvision/models/detection/backbone_utils.py

Lines changed: 18 additions & 9 deletions
@@ -6,7 +6,8 @@
 from torchvision.ops.feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool

 from .. import mobilenet, resnet
-from .._utils import IntermediateLayerGetter
+from .._api import WeightsEnum
+from .._utils import IntermediateLayerGetter, handle_legacy_interface


 class BackboneWithFPN(nn.Module):
@@ -55,9 +56,13 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]:
         return x


+@handle_legacy_interface(
+    weights=("pretrained", True),  # type: ignore[arg-type]
+)
 def resnet_fpn_backbone(
+    *,
     backbone_name: str,
-    pretrained: bool,
+    weights: Optional[WeightsEnum],
     norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
     trainable_layers: int = 3,
     returned_layers: Optional[List[int]] = None,
@@ -69,7 +74,7 @@ def resnet_fpn_backbone(
     Examples::

         >>> from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
-        >>> backbone = resnet_fpn_backbone('resnet50', pretrained=True, trainable_layers=3)
+        >>> backbone = resnet_fpn_backbone('resnet50', weights=ResNet50_Weights.DEFAULT, trainable_layers=3)
         >>> # get some dummy image
         >>> x = torch.rand(1,3,64,64)
         >>> # compute the output
@@ -85,7 +90,7 @@ def resnet_fpn_backbone(
     Args:
         backbone_name (string): resnet architecture. Possible values are 'resnet18', 'resnet34', 'resnet50',
            'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2'
-        pretrained (bool): If True, returns a model with backbone pre-trained on Imagenet
+        weights (WeightsEnum, optional): The pretrained weights for the model
         norm_layer (callable): it is recommended to use the default value. For details visit:
            (https://github.com/facebookresearch/maskrcnn-benchmark/issues/267)
         trainable_layers (int): number of trainable (not frozen) layers starting from final block.
@@ -98,7 +103,7 @@ def resnet_fpn_backbone(
             a new list of feature maps and their corresponding names. By
             default a ``LastLevelMaxPool`` is used.
     """
-    backbone = resnet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer)
+    backbone = resnet.__dict__[backbone_name](weights=weights, norm_layer=norm_layer)
     return _resnet_fpn_extractor(backbone, trainable_layers, returned_layers, extra_blocks)
@@ -135,13 +140,13 @@ def _resnet_fpn_extractor(


 def _validate_trainable_layers(
-    pretrained: bool,
+    is_trained: bool,
     trainable_backbone_layers: Optional[int],
     max_value: int,
     default_value: int,
 ) -> int:
     # don't freeze any layers if pretrained model or backbone is not used
-    if not pretrained:
+    if not is_trained:
         if trainable_backbone_layers is not None:
             warnings.warn(
                 "Changing trainable_backbone_layers has not effect if "
@@ -160,16 +165,20 @@ def _validate_trainable_layers(
         return trainable_backbone_layers


+@handle_legacy_interface(
+    weights=("pretrained", True),  # type: ignore[arg-type]
+)
 def mobilenet_backbone(
+    *,
     backbone_name: str,
-    pretrained: bool,
+    weights: Optional[WeightsEnum],
     fpn: bool,
     norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
     trainable_layers: int = 2,
     returned_layers: Optional[List[int]] = None,
     extra_blocks: Optional[ExtraFPNBlock] = None,
 ) -> nn.Module:
-    backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer)
+    backbone = mobilenet.__dict__[backbone_name](weights=weights, norm_layer=norm_layer)
     return _mobilenet_extractor(backbone, fpn, trainable_layers, returned_layers, extra_blocks)
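The `@handle_legacy_interface` decorator added above is what keeps old `pretrained=True` call sites working (mapping them onto `weights=` along a deprecation path) while the public signature becomes keyword-only. A hedged sketch of the new-style call, mirroring the updated docstring example:

import torch
from torchvision.models import ResNet50_Weights
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone

backbone = resnet_fpn_backbone(
    backbone_name="resnet50",          # keyword-only now: note the `*` in the signature
    weights=ResNet50_Weights.DEFAULT,  # replaces the removed `pretrained: bool`
    trainable_layers=3,
)

out = backbone(torch.rand(1, 3, 64, 64))
print([(k, v.shape) for k, v in out.items()])  # FPN feature maps '0'-'3' plus 'pool'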

torchvision/models/detection/faster_rcnn.py

Lines changed: 4 additions & 4 deletions
@@ -117,7 +117,7 @@ class FasterRCNN(GeneralizedRCNN):
         >>> from torchvision.models.detection.rpn import AnchorGenerator
         >>> # load a pre-trained model for classification and return
         >>> # only the features
-        >>> backbone = torchvision.models.mobilenet_v2(pretrained=True).features
+        >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features
         >>> # FasterRCNN needs to know the number of
         >>> # output channels in a backbone. For mobilenet_v2, it's 1280
         >>> # so we need to add it here
@@ -415,7 +415,7 @@ def fasterrcnn_resnet50_fpn(

     Example::

-        >>> model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
+        >>> model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)
         >>> # For training
         >>> images, boxes = torch.rand(4, 3, 600, 1200), torch.rand(4, 11, 4)
         >>> boxes[:, :, 2:4] = boxes[:, :, 0:2] + boxes[:, :, 2:4]
@@ -532,7 +532,7 @@ def fasterrcnn_mobilenet_v3_large_320_fpn(

     Example::

-        >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn(pretrained=True)
+        >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn(weights=FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.DEFAULT)
         >>> model.eval()
         >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
         >>> predictions = model(x)
@@ -589,7 +589,7 @@ def fasterrcnn_mobilenet_v3_large_fpn(

     Example::

-        >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=True)
+        >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(weights=FasterRCNN_MobileNet_V3_Large_FPN_Weights.DEFAULT)
         >>> model.eval()
         >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
         >>> predictions = model(x)
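For completeness, an end-to-end version of the updated docstring examples (enum names taken from this diff):

import torch
from torchvision.models.detection import (
    fasterrcnn_mobilenet_v3_large_fpn,
    FasterRCNN_MobileNet_V3_Large_FPN_Weights,
)

model = fasterrcnn_mobilenet_v3_large_fpn(weights=FasterRCNN_MobileNet_V3_Large_FPN_Weights.DEFAULT)
model.eval()
x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
predictions = model(x)  # one dict of boxes/labels/scores per input image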
