
Porting docs, examples, tutorials and galleries #5620


Merged: 13 commits, Mar 15, 2022
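Every diff below applies the same migration: the deprecated `pretrained=True/False` flag is replaced by the new multi-weight API (prototype at the time of this PR, stable as of torchvision 0.13). A minimal sketch of the pattern, using one of the models touched in this PR; the dummy input is ours, not from the diff:

import torch
from torchvision.models import resnet18, ResNet18_Weights

# Old style (deprecated): model = resnet18(pretrained=True)
weights = ResNet18_Weights.DEFAULT  # "best available" weights for this architecture
model = resnet18(weights=weights).eval()
preprocess = weights.transforms()  # the inference transforms the weights were trained with

img = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)  # dummy uint8 image
batch = preprocess(img).unsqueeze(0)
scores = model(batch).squeeze(0).softmax(0)  # class probabilities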
13 changes: 10 additions & 3 deletions android/test_app/make_assets.py
@@ -1,11 +1,18 @@
 import torch
 import torchvision
 from torch.utils.mobile_optimizer import optimize_for_mobile
+from torchvision.models.detection import (
+    fasterrcnn_mobilenet_v3_large_320_fpn,
+    FasterRCNN_MobileNet_V3_Large_320_FPN_Weights,
+)

 print(torch.__version__)

-model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn(
-    pretrained=True, box_score_thresh=0.7, rpn_post_nms_top_n_test=100, rpn_score_thresh=0.4, rpn_pre_nms_top_n_test=150
+model = fasterrcnn_mobilenet_v3_large_320_fpn(
+    weights=FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.DEFAULT,
+    box_score_thresh=0.7,
+    rpn_post_nms_top_n_test=100,
+    rpn_score_thresh=0.4,
+    rpn_pre_nms_top_n_test=150,
 )

 model.eval()
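The diff is truncated before the export step. For orientation, a hedged sketch of how a model like this is typically exported for mobile given the imports above; the output path is hypothetical, not taken from this PR:

script_model = torch.jit.script(model)
opt_model = optimize_for_mobile(script_model)  # mobile-specific graph optimizations
opt_model._save_for_lite_interpreter("app/src/main/assets/frcnn.ptl")  # hypothetical asset path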
2 changes: 1 addition & 1 deletion examples/cpp/hello_world/trace_model.py
@@ -6,7 +6,7 @@
 HERE = osp.dirname(osp.abspath(__file__))
 ASSETS = osp.dirname(osp.dirname(HERE))

-model = torchvision.models.resnet18(pretrained=False)
+model = torchvision.models.resnet18()
 model.eval()

 traced_model = torch.jit.script(model)
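Under the multi-weight API, `weights=None` is the default, so the old `pretrained=False` becomes a bare call, as above. A quick sketch of the equivalence, assuming torchvision >= 0.13:

import torchvision

rand_model = torchvision.models.resnet18()  # weights=None: random init, the old pretrained=False
pretrained_model = torchvision.models.resnet18(weights="DEFAULT")  # string alias for ResNet18_Weights.DEFAULT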
29 changes: 13 additions & 16 deletions gallery/plot_optical_flow.py
@@ -19,7 +19,6 @@
 import torch
 import matplotlib.pyplot as plt
 import torchvision.transforms.functional as F
-import torchvision.transforms as T


 plt.rcParams["savefig.bbox"] = "tight"
@@ -88,24 +87,19 @@ def plot(imgs, **imshow_kwargs):
 # reduce the image sizes for the example to run faster. Image dimension must be
 # divisible by 8.

+from torchvision.models.optical_flow import Raft_Large_Weights
+
+weights = Raft_Large_Weights.DEFAULT
+transforms = weights.transforms()

-def preprocess(batch):
-    transforms = T.Compose(
-        [
-            T.ConvertImageDtype(torch.float32),
-            T.Normalize(mean=0.5, std=0.5),  # map [0, 1] into [-1, 1]
-            T.Resize(size=(520, 960)),
-        ]
-    )
-    batch = transforms(batch)
-    return batch
-
-
-# If you can, run this example on a GPU, it will be a lot faster.
-device = "cuda" if torch.cuda.is_available() else "cpu"
+
+def preprocess(img1_batch, img2_batch):
+    img1_batch = F.resize(img1_batch, size=[520, 960])
+    img2_batch = F.resize(img2_batch, size=[520, 960])
+    return transforms(img1_batch, img2_batch)


-img1_batch = preprocess(img1_batch).to(device)
-img2_batch = preprocess(img2_batch).to(device)
+img1_batch, img2_batch = preprocess(img1_batch, img2_batch)

 print(f"shape = {img1_batch.shape}, dtype = {img1_batch.dtype}")
@@ -121,7 +115,10 @@ def preprocess(batch):

 from torchvision.models.optical_flow import raft_large

-model = raft_large(pretrained=True, progress=False).to(device)
+# If you can, run this example on a GPU, it will be a lot faster.
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+model = raft_large(weights=Raft_Large_Weights.DEFAULT, progress=False).to(device)
 model = model.eval()

 list_of_flows = model(img1_batch.to(device), img2_batch.to(device))
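Taken together, the two hunks above amount to the following end-to-end flow. A consolidated sketch, assuming the stable multi-weight API (torchvision >= 0.13) and substituting dummy frames for the gallery's video frames:

import torch
import torchvision.transforms.functional as F
from torchvision.models.optical_flow import raft_large, Raft_Large_Weights

weights = Raft_Large_Weights.DEFAULT
transforms = weights.transforms()

# Dummy frame batches standing in for the gallery's video frames
img1_batch = torch.randint(0, 256, (2, 3, 520, 960), dtype=torch.uint8)
img2_batch = torch.randint(0, 256, (2, 3, 520, 960), dtype=torch.uint8)


def preprocess(img1_batch, img2_batch):
    # RAFT needs dimensions divisible by 8, hence the fixed resize
    img1_batch = F.resize(img1_batch, size=[520, 960])
    img2_batch = F.resize(img2_batch, size=[520, 960])
    return transforms(img1_batch, img2_batch)


device = "cuda" if torch.cuda.is_available() else "cpu"
model = raft_large(weights=weights, progress=False).to(device).eval()

img1_batch, img2_batch = preprocess(img1_batch, img2_batch)
list_of_flows = model(img1_batch.to(device), img2_batch.to(device))
predicted_flow = list_of_flows[-1]  # RAFT is recurrent; the last flow estimate is the most refined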
8 changes: 5 additions & 3 deletions gallery/plot_repurposing_annotations.py
@@ -139,12 +139,14 @@ def show(imgs):
 # Here is a demo with a Faster R-CNN model loaded from
 # :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn`

-from torchvision.models.detection import fasterrcnn_resnet50_fpn
+from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_Weights

-model = fasterrcnn_resnet50_fpn(pretrained=True, progress=False)
+weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
+model = fasterrcnn_resnet50_fpn(weights=weights, progress=False)
 print(img.size())

-img = F.convert_image_dtype(img, torch.float)
+transforms = weights.transforms()
+img = transforms(img)
 target = {}
 target["boxes"] = boxes
 target["labels"] = labels = torch.ones((masks.size(0),), dtype=torch.int64)
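The preset returned by `weights.transforms()` is what replaces the manual `F.convert_image_dtype` call: for detection models it converts a uint8 image to the float tensor the model expects. A short sketch with a dummy image, assuming torchvision >= 0.13:

import torch
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights

transforms = FasterRCNN_ResNet50_FPN_Weights.DEFAULT.transforms()

img_int = torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8)  # dummy uint8 image
img = transforms(img_int)
print(img.dtype, img.min().item(), img.max().item())  # torch.float32, values in [0.0, 1.0]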
12 changes: 4 additions & 8 deletions gallery/plot_scripted_tensor_transforms.py
@@ -85,20 +85,16 @@ def show(imgs):
 # Let's define a ``Predictor`` module that transforms the input tensor and then
 # applies an ImageNet model on it.

-from torchvision.models import resnet18
+from torchvision.models import resnet18, ResNet18_Weights


 class Predictor(nn.Module):

     def __init__(self):
         super().__init__()
-        self.resnet18 = resnet18(pretrained=True, progress=False).eval()
-        self.transforms = nn.Sequential(
-            T.Resize([256, ]),  # We use single int value inside a list due to torchscript type restrictions
-            T.CenterCrop(224),
-            T.ConvertImageDtype(torch.float),
-            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
-        )
+        weights = ResNet18_Weights.DEFAULT
+        self.resnet18 = resnet18(weights=weights, progress=False).eval()
+        self.transforms = weights.transforms()

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         with torch.no_grad():
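Because the presets returned by `weights.transforms()` are plain `nn.Module`s, the rewritten `Predictor` should stay scriptable, which is the point of this gallery. A hedged usage sketch reusing the `Predictor` class from the diff above; the dummy batch is ours:

import torch

predictor = Predictor().eval()
scripted_predictor = torch.jit.script(predictor)  # works because the preset transforms are nn.Modules

batch = torch.randint(0, 256, (2, 3, 256, 256), dtype=torch.uint8)  # dummy uint8 batch
predictions = scripted_predictor(batch)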
39 changes: 27 additions & 12 deletions gallery/plot_visualization_utils.py
@@ -73,14 +73,17 @@ def show(imgs):
 # :func:`~torchvision.models.detection.ssd300_vgg16`. For more details
 # on the output of such models, you may refer to :ref:`instance_seg_output`.

-from torchvision.models.detection import fasterrcnn_resnet50_fpn
-from torchvision.transforms.functional import convert_image_dtype
+from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_Weights


 batch_int = torch.stack([dog1_int, dog2_int])
-batch = convert_image_dtype(batch_int, dtype=torch.float)

-model = fasterrcnn_resnet50_fpn(pretrained=True, progress=False)
+weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
+transforms = weights.transforms()
+
+batch = transforms(batch_int)
+
+model = fasterrcnn_resnet50_fpn(weights=weights, progress=False)
 model = model.eval()

 outputs = model(batch)
@@ -120,13 +123,15 @@ def show(imgs):
 # images must be normalized before they're passed to a semantic segmentation
 # model.

-from torchvision.models.segmentation import fcn_resnet50
+from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights
+
+weights = FCN_ResNet50_Weights.DEFAULT
+transforms = weights.transforms()

-model = fcn_resnet50(pretrained=True, progress=False)
+model = fcn_resnet50(weights=weights, progress=False)
 model = model.eval()

-normalized_batch = F.normalize(batch, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
+normalized_batch = transforms(batch)
 output = model(normalized_batch)['out']
 print(output.shape, output.min().item(), output.max().item())
@@ -262,8 +267,14 @@ def show(imgs):
 # of them may not have masks, like
 # :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn`.

-from torchvision.models.detection import maskrcnn_resnet50_fpn
-model = maskrcnn_resnet50_fpn(pretrained=True, progress=False)
+from torchvision.models.detection import maskrcnn_resnet50_fpn, MaskRCNN_ResNet50_FPN_Weights
+
+weights = MaskRCNN_ResNet50_FPN_Weights.DEFAULT
+transforms = weights.transforms()
+
+batch = transforms(batch_int)
+
+model = maskrcnn_resnet50_fpn(weights=weights, progress=False)
 model = model.eval()

 output = model(batch)
@@ -378,13 +389,17 @@ def show(imgs):
 # Note that the keypoint detection model does not need normalized images.
 #

-from torchvision.models.detection import keypointrcnn_resnet50_fpn
+from torchvision.models.detection import keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights
 from torchvision.io import read_image

 person_int = read_image(str(Path("assets") / "person1.jpg"))
-person_float = convert_image_dtype(person_int, dtype=torch.float)

-model = keypointrcnn_resnet50_fpn(pretrained=True, progress=False)
+weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
+transforms = weights.transforms()
+
+person_float = transforms(person_int)
+
+model = keypointrcnn_resnet50_fpn(weights=weights, progress=False)
 model = model.eval()

 outputs = model([person_float])
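Beyond the transforms, the weights enums also bundle metadata such as class names, so visualizations no longer need a hand-maintained label list. A short sketch; `meta["categories"]` is the documented key, and the lookup at the end is illustrative:

from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights

weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
categories = weights.meta["categories"]  # COCO class names, aligned with predicted label ids
# e.g. names = [categories[i] for i in outputs[0]["labels"]]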
13 changes: 10 additions & 3 deletions ios/VisionTestApp/make_assets.py
@@ -1,11 +1,18 @@
 import torch
 import torchvision
 from torch.utils.mobile_optimizer import optimize_for_mobile
+from torchvision.models.detection import (
+    fasterrcnn_mobilenet_v3_large_320_fpn,
+    FasterRCNN_MobileNet_V3_Large_320_FPN_Weights,
+)

 print(torch.__version__)

-model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn(
-    pretrained=True, box_score_thresh=0.7, rpn_post_nms_top_n_test=100, rpn_score_thresh=0.4, rpn_pre_nms_top_n_test=150
+model = fasterrcnn_mobilenet_v3_large_320_fpn(
+    weights=FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.DEFAULT,
+    box_score_thresh=0.7,
+    rpn_post_nms_top_n_test=100,
+    rpn_score_thresh=0.4,
+    rpn_pre_nms_top_n_test=150,
 )

 model.eval()
2 changes: 1 addition & 1 deletion test/tracing/frcnn/trace_model.py
@@ -6,7 +6,7 @@
 HERE = osp.dirname(osp.abspath(__file__))
 ASSETS = osp.dirname(osp.dirname(HERE))

-model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False)
+model = torchvision.models.detection.fasterrcnn_resnet50_fpn()
 model.eval()

 traced_model = torch.jit.script(model)
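Note that this test uses `torch.jit.script` even though the file is named trace_model.py: detection models contain data-dependent control flow that tracing would freeze into a single code path. A minimal sketch of the same round trip, assuming torchvision >= 0.13:

import io
import torch
import torchvision

model = torchvision.models.detection.fasterrcnn_resnet50_fpn().eval()  # weights=None: random init
scripted = torch.jit.script(model)  # scripting preserves the model's control flow

buffer = io.BytesIO()
torch.jit.save(scripted, buffer)  # serialize the scripted module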
2 changes: 1 addition & 1 deletion torchvision/transforms/_presets.py
@@ -107,7 +107,7 @@ def forward(self, img: Tensor, target: Optional[Tensor] = None) -> Tuple[Tensor,

 class OpticalFlowEval(nn.Module):
     def forward(
-        self, img1: Tensor, img2: Tensor, flow: Optional[Tensor], valid_flow_mask: Optional[Tensor]
+        self, img1: Tensor, img2: Tensor, flow: Optional[Tensor] = None, valid_flow_mask: Optional[Tensor] = None
     ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:

         img1, img2, flow, valid_flow_mask = self._pil_or_numpy_to_tensor(img1, img2, flow, valid_flow_mask)
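These new defaults are what let callers, like the optical-flow gallery above, invoke the preset with two images only at inference time. A minimal sketch of the two-argument call, with dummy frames:

import torch
from torchvision.models.optical_flow import Raft_Large_Weights

transforms = Raft_Large_Weights.DEFAULT.transforms()

img1 = torch.randint(0, 256, (1, 3, 520, 960), dtype=torch.uint8)  # dummy frames
img2 = torch.randint(0, 256, (1, 3, 520, 960), dtype=torch.uint8)
result = transforms(img1, img2)  # flow and valid_flow_mask no longer need to be passed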