Skip to content

Commit 4fc15fd

Browse files
committed
Allow setting image size in reference training.
1 parent 96931c1 commit 4fc15fd

File tree

3 files changed

+14
-3
lines changed

3 files changed

+14
-3
lines changed

references/detection/presets.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ def __init__(self, *, data_augmentation, hflip_prob=0.5, mean=(123.0, 117.0, 104
1515
elif data_augmentation == "lsj":
1616
self.transforms = T.Compose(
1717
[
18-
T.ScaleJitter(target_size=(1024, 1024)),
19-
T.FixedSizeCrop(size=(1024, 1024), fill=mean),
2018
T.RandomHorizontalFlip(p=hflip_prob),
19+
T.ScaleJitter(target_size=(1024, 1024)),
20+
T.FixedSizeCrop(size=(1024, 1024), fill=0),
2121
T.PILToTensor(),
2222
T.ConvertImageDtype(torch.float),
2323
]

references/detection/train.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,12 @@ def get_args_parser(add_help=True):
126126
parser.add_argument(
127127
"--data-augmentation", default="hflip", type=str, help="data augmentation policy (default: hflip)"
128128
)
129+
parser.add_argument(
130+
"--image-min-size", default=800, type=int, help="resize images so that the smallest side is equal to this"
131+
)
132+
parser.add_argument(
133+
"--image-max-size", default=1333, type=int, help="resize images so that the largest side is less than this"
134+
)
129135
parser.add_argument(
130136
"--sync-bn",
131137
dest="sync_bn",
@@ -210,7 +216,11 @@ def main(args):
210216
)
211217

212218
print("Creating model")
213-
kwargs = {"trainable_backbone_layers": args.trainable_backbone_layers}
219+
kwargs = {
220+
"trainable_backbone_layers": args.trainable_backbone_layers,
221+
"min_size": args.image_min_size,
222+
"max_size": args.image_max_size,
223+
}
214224
if args.data_augmentation in ["multiscale", "lsj"]:
215225
kwargs["_skip_resize"] = True
216226
if "rcnn" in args.model:

torchvision/ops/feature_pyramid_network.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from collections import OrderedDict
22
from typing import Callable, Dict, List, Optional, Tuple
33

4+
import torch
45
import torch.nn.functional as F
56
from torch import nn, Tensor
67

0 commit comments

Comments
 (0)