
Commit c49e346

mthrok authored and facebook-github-bot committed
Add MobileNetV3 architecture for Classification (#3252)
Summary: Add MobileNetV3 Architecture in TorchVision (#3182)

* Adding implementation of the network architecture.
* Adding RMSprop support in train.py.
* Adding auto-augment and random-erase to the training scripts.
* Adding support for the reduced tail on MobileNetV3.
* Tagging blocks with comments.
* Adding documentation, the pre-trained model URL and a minor refactoring.
* Better handling of supported but untrained models.

Reviewed By: datumbox
Differential Revision: D25954557
fbshipit-source-id: f7d72a81a2ec92cbbbf3bd86c68ae0a426626cc7
1 parent 01dd815 commit c49e346

File tree

10 files changed (+357, -30 lines)


docs/source/models.rst

Lines changed: 16 additions & 4 deletions
@@ -22,7 +22,8 @@ architectures for image classification:
 - `Inception`_ v3
 - `GoogLeNet`_
 - `ShuffleNet`_ v2
-- `MobileNet`_ v2
+- `MobileNetV2`_
+- `MobileNetV3`_
 - `ResNeXt`_
 - `Wide ResNet`_
 - `MNASNet`_
@@ -40,7 +41,9 @@ You can construct a model with random weights by calling its constructor:
     inception = models.inception_v3()
     googlenet = models.googlenet()
     shufflenet = models.shufflenet_v2_x1_0()
-    mobilenet = models.mobilenet_v2()
+    mobilenet_v2 = models.mobilenet_v2()
+    mobilenet_v3_large = models.mobilenet_v3_large()
+    mobilenet_v3_small = models.mobilenet_v3_small()
     resnext50_32x4d = models.resnext50_32x4d()
     wide_resnet50_2 = models.wide_resnet50_2()
     mnasnet = models.mnasnet1_0()
@@ -59,7 +62,8 @@ These can be constructed by passing ``pretrained=True``:
     inception = models.inception_v3(pretrained=True)
     googlenet = models.googlenet(pretrained=True)
     shufflenet = models.shufflenet_v2_x1_0(pretrained=True)
-    mobilenet = models.mobilenet_v2(pretrained=True)
+    mobilenet_v2 = models.mobilenet_v2(pretrained=True)
+    mobilenet_v3_large = models.mobilenet_v3_large(pretrained=True)
     resnext50_32x4d = models.resnext50_32x4d(pretrained=True)
     wide_resnet50_2 = models.wide_resnet50_2(pretrained=True)
     mnasnet = models.mnasnet1_0(pretrained=True)
@@ -137,6 +141,7 @@ Inception v3 22.55 6.44
 GoogleNet                  30.22    10.47
 ShuffleNet V2              30.64    11.68
 MobileNet V2               28.12    9.71
+MobileNet V3 Large         25.96    8.66
 ResNeXt-50-32x4d           22.38    6.30
 ResNeXt-101-32x8d          20.69    5.47
 Wide ResNet-50-2           21.49    5.91
@@ -153,7 +158,8 @@ MNASNet 1.0 26.49 8.456
 .. _Inception: https://arxiv.org/abs/1512.00567
 .. _GoogLeNet: https://arxiv.org/abs/1409.4842
 .. _ShuffleNet: https://arxiv.org/abs/1807.11164
-.. _MobileNet: https://arxiv.org/abs/1801.04381
+.. _MobileNetV2: https://arxiv.org/abs/1801.04381
+.. _MobileNetV3: https://arxiv.org/abs/1905.02244
 .. _ResNeXt: https://arxiv.org/abs/1611.05431
 .. _MNASNet: https://arxiv.org/abs/1807.11626
@@ -231,6 +237,12 @@ MobileNet v2

 .. autofunction:: mobilenet_v2

+MobileNet v3
+-------------
+
+.. autofunction:: mobilenet_v3_large
+.. autofunction:: mobilenet_v3_small
+
 ResNext
 -------
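
As a quick sanity check of the documented constructors, here is a minimal sketch of running the new pretrained model on one preprocessed image. The Resize/CenterCrop/Normalize values are the standard ImageNet evaluation settings used for the accuracies in the table above; they are not part of this diff.

```python
import torch
from PIL import Image
from torchvision import models, transforms

# Standard ImageNet evaluation preprocessing.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

model = models.mobilenet_v3_large(pretrained=True)
model.eval()

img = Image.new('RGB', (500, 375))  # stand-in for a real image
x = preprocess(img).unsqueeze(0)    # batch of one
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # torch.Size([1, 1000])
```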

hubconf.py

Lines changed: 2 additions & 1 deletion
@@ -11,7 +11,8 @@
 from torchvision.models.vgg import vgg11, vgg13, vgg16, vgg19, vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn
 from torchvision.models.googlenet import googlenet
 from torchvision.models.shufflenetv2 import shufflenet_v2_x0_5, shufflenet_v2_x1_0
-from torchvision.models.mobilenet import mobilenet_v2
+from torchvision.models.mobilenetv2 import mobilenet_v2
+from torchvision.models.mobilenetv3 import mobilenet_v3_large, mobilenet_v3_small
 from torchvision.models.mnasnet import mnasnet0_5, mnasnet0_75, mnasnet1_0, \
     mnasnet1_3
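
Since hubconf.py now re-exports the new constructors, they should also be reachable through torch.hub. A sketch, assuming a torchvision release or checkout that includes this commit:

```python
import torch

# Loads the entry point registered in hubconf.py above.
model = torch.hub.load('pytorch/vision', 'mobilenet_v3_large', pretrained=True)
model.eval()
```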

references/classification/README.md

Lines changed: 10 additions & 0 deletions
@@ -53,6 +53,16 @@ python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
     --lr-step-size 1 --lr-gamma 0.98
 ```

+
+### MobileNetV3 Large
+```
+python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
+    --model mobilenet_v3_large --epochs 600 --opt rmsprop --batch-size 128 --lr 0.064\
+    --wd 0.00001 --lr-step-size 2 --lr-gamma 0.973 --auto-augment imagenet --random-erase 0.2
+```
+
+Then we averaged the parameters of the last 3 checkpoints that improved the Acc@1. See [#3182](https://github.com/pytorch/vision/pull/3182) for details.
+
 ## Mixed precision training
 Automatic Mixed Precision (AMP) training on GPU for Pytorch can be enabled with the [NVIDIA Apex extension](https://github.com/NVIDIA/apex).
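
The checkpoint averaging mentioned above happens outside the committed scripts. A minimal sketch of the idea, where the file names are hypothetical and the 'model' key is assumed to match how train.py saves checkpoints:

```python
import torch

# Hypothetical paths to the last three checkpoints that improved Acc@1.
paths = ['model_597.pth', 'model_598.pth', 'model_599.pth']
state_dicts = [torch.load(p, map_location='cpu')['model'] for p in paths]

averaged = {}
for key, first in state_dicts[0].items():
    if first.is_floating_point():
        averaged[key] = torch.stack([sd[key] for sd in state_dicts]).mean(dim=0)
    else:
        # Integer buffers (e.g. BatchNorm's num_batches_tracked) are copied, not averaged.
        averaged[key] = first

torch.save({'model': averaged}, 'mobilenet_v3_large_averaged.pth')
```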

references/classification/train.py

Lines changed: 33 additions & 16 deletions
@@ -79,7 +79,7 @@ def _get_cache_path(filepath):
     return cache_path


-def load_data(traindir, valdir, cache_dataset, distributed):
+def load_data(traindir, valdir, args):
     # Data loading code
     print("Loading data")
     normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
@@ -88,28 +88,36 @@ def load_data(traindir, valdir, cache_dataset, distributed):
     print("Loading training data")
     st = time.time()
     cache_path = _get_cache_path(traindir)
-    if cache_dataset and os.path.exists(cache_path):
+    if args.cache_dataset and os.path.exists(cache_path):
         # Attention, as the transforms are also cached!
         print("Loading dataset_train from {}".format(cache_path))
         dataset, _ = torch.load(cache_path)
     else:
+        trans = [
+            transforms.RandomResizedCrop(224),
+            transforms.RandomHorizontalFlip(),
+        ]
+        if args.auto_augment is not None:
+            aa_policy = transforms.AutoAugmentPolicy(args.auto_augment)
+            trans.append(transforms.AutoAugment(policy=aa_policy))
+        trans.extend([
+            transforms.ToTensor(),
+            normalize,
+        ])
+        if args.random_erase > 0:
+            trans.append(transforms.RandomErasing(p=args.random_erase))
         dataset = torchvision.datasets.ImageFolder(
             traindir,
-            transforms.Compose([
-                transforms.RandomResizedCrop(224),
-                transforms.RandomHorizontalFlip(),
-                transforms.ToTensor(),
-                normalize,
-            ]))
-        if cache_dataset:
+            transforms.Compose(trans))
+        if args.cache_dataset:
             print("Saving dataset_train to {}".format(cache_path))
             utils.mkdir(os.path.dirname(cache_path))
             utils.save_on_master((dataset, traindir), cache_path)
     print("Took", time.time() - st)

     print("Loading validation data")
     cache_path = _get_cache_path(valdir)
-    if cache_dataset and os.path.exists(cache_path):
+    if args.cache_dataset and os.path.exists(cache_path):
         # Attention, as the transforms are also cached!
         print("Loading dataset_test from {}".format(cache_path))
         dataset_test, _ = torch.load(cache_path)
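
As an aside, with both new flags enabled the pipeline assembled in the hunk above reduces to the fixed composition below, using the policy name and erase probability from the README command. The normalization std is the standard ImageNet value elided by the diff context.

```python
from torchvision import transforms

# Equivalent training transforms for --auto-augment imagenet --random-erase 0.2.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.AutoAugment(policy=transforms.AutoAugmentPolicy('imagenet')),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
    transforms.RandomErasing(p=0.2),
])
```
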
@@ -122,13 +130,13 @@ def load_data(traindir, valdir, cache_dataset, distributed):
             transforms.ToTensor(),
             normalize,
         ]))
-        if cache_dataset:
+        if args.cache_dataset:
             print("Saving dataset_test to {}".format(cache_path))
             utils.mkdir(os.path.dirname(cache_path))
             utils.save_on_master((dataset_test, valdir), cache_path)

     print("Creating data loaders")
-    if distributed:
+    if args.distributed:
         train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
         test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
     else:
@@ -155,8 +163,7 @@ def main(args):

     train_dir = os.path.join(args.data_path, 'train')
     val_dir = os.path.join(args.data_path, 'val')
-    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir,
-                                                                   args.cache_dataset, args.distributed)
+    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
     data_loader = torch.utils.data.DataLoader(
         dataset, batch_size=args.batch_size,
         sampler=train_sampler, num_workers=args.workers, pin_memory=True)
@@ -173,8 +180,15 @@ def main(args):

     criterion = nn.CrossEntropyLoss()

-    optimizer = torch.optim.SGD(
-        model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
+    opt_name = args.opt.lower()
+    if opt_name == 'sgd':
+        optimizer = torch.optim.SGD(
+            model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
+    elif opt_name == 'rmsprop':
+        optimizer = torch.optim.RMSprop(model.parameters(), lr=args.lr, momentum=args.momentum,
+                                        weight_decay=args.weight_decay, eps=0.0316, alpha=0.9)
+    else:
+        raise RuntimeError("Invalid optimizer {}. Only SGD and RMSprop are supported.".format(args.opt))

     if args.apex:
         model, optimizer = amp.initialize(model, optimizer,
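
Called with the README's hyperparameters, the new RMSprop branch amounts to the following sketch; eps=0.0316 and alpha=0.9 are the hard-coded values from the hunk above.

```python
import torch
from torchvision import models

model = models.mobilenet_v3_large()
# What the rmsprop branch constructs with --lr 0.064 --wd 0.00001
# and the default momentum of 0.9.
optimizer = torch.optim.RMSprop(model.parameters(), lr=0.064, momentum=0.9,
                                weight_decay=1e-5, eps=0.0316, alpha=0.9)
```
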
@@ -238,6 +252,7 @@ def parse_args():
                         help='number of total epochs to run')
     parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
                         help='number of data loading workers (default: 16)')
+    parser.add_argument('--opt', default='sgd', type=str, help='optimizer')
     parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate')
     parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                         help='momentum')
@@ -275,6 +290,8 @@ def parse_args():
                         help="Use pre-trained models from the modelzoo",
                         action="store_true",
                         )
+    parser.add_argument('--auto-augment', default=None, help='auto augment policy (default: None)')
+    parser.add_argument('--random-erase', default=0.0, type=float, help='random erasing probability (default: 0.0)')

     # Mixed precision training parameters
     parser.add_argument('--apex', action='store_true',
Two binary files changed (contents not shown).

test/test_models.py

Lines changed: 9 additions & 8 deletions
@@ -275,16 +275,17 @@ def test_mobilenet_v2_residual_setting(self):
         out = model(x)
         self.assertEqual(out.shape[-1], 1000)

-    def test_mobilenetv2_norm_layer(self):
-        model = models.__dict__["mobilenet_v2"]()
-        self.assertTrue(any(isinstance(x, nn.BatchNorm2d) for x in model.modules()))
+    def test_mobilenet_norm_layer(self):
+        for name in ["mobilenet_v2", "mobilenet_v3_large", "mobilenet_v3_small"]:
+            model = models.__dict__[name]()
+            self.assertTrue(any(isinstance(x, nn.BatchNorm2d) for x in model.modules()))

-        def get_gn(num_channels):
-            return nn.GroupNorm(32, num_channels)
+            def get_gn(num_channels):
+                return nn.GroupNorm(32, num_channels)

-        model = models.__dict__["mobilenet_v2"](norm_layer=get_gn)
-        self.assertFalse(any(isinstance(x, nn.BatchNorm2d) for x in model.modules()))
-        self.assertTrue(any(isinstance(x, nn.GroupNorm) for x in model.modules()))
+            model = models.__dict__[name](norm_layer=get_gn)
+            self.assertFalse(any(isinstance(x, nn.BatchNorm2d) for x in model.modules()))
+            self.assertTrue(any(isinstance(x, nn.GroupNorm) for x in model.modules()))

     def test_inception_v3_eval(self):
         # replacement for models.inception_v3(pretrained=True) that does not download weights
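
Outside the test suite, the same norm_layer hook lets callers swap normalization layers. A sketch using a single group, which divides any channel width (unlike a fixed group size, which requires every block width to be divisible by it):

```python
import torch.nn as nn
from torchvision import models

def get_gn(num_channels):
    # One group normalizes over all channels of the block at once.
    return nn.GroupNorm(1, num_channels)

model = models.mobilenet_v3_small(norm_layer=get_gn)
assert not any(isinstance(m, nn.BatchNorm2d) for m in model.modules())
assert any(isinstance(m, nn.GroupNorm) for m in model.modules())
```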

torchvision/models/mobilenet.py

Lines changed: 2 additions & 1 deletion
@@ -1,3 +1,4 @@
 from .mobilenetv2 import MobileNetV2, mobilenet_v2, __all__ as mv2_all
+from .mobilenetv3 import MobileNetV3, mobilenet_v3_large, mobilenet_v3_small, __all__ as mv3_all

-__all__ = mv2_all
+__all__ = mv2_all + mv3_all
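
This shim keeps the old torchvision.models.mobilenet import path alive while the implementations move into per-architecture modules; both paths resolve to the same objects:

```python
# Old-style import path, still supported via the shim above.
from torchvision.models.mobilenet import MobileNetV2, mobilenet_v3_large

# New per-architecture modules.
from torchvision.models.mobilenetv2 import MobileNetV2 as MV2
from torchvision.models.mobilenetv3 import mobilenet_v3_large as mv3_large

assert MobileNetV2 is MV2
assert mobilenet_v3_large is mv3_large
```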

torchvision/models/mobilenetv2.py

Lines changed: 3 additions & 0 deletions
@@ -53,6 +53,7 @@ def __init__(
             norm_layer(out_planes),
             activation_layer(inplace=True)
         )
+        self.out_channels = out_planes


 # necessary for backwards compatibility
@@ -90,6 +91,8 @@ def __init__(
             norm_layer(oup),
         ])
         self.conv = nn.Sequential(*layers)
+        self.out_channels = oup
+        self.is_strided = stride > 1

     def forward(self, x: Tensor) -> Tensor:
         if self.use_res_connect:
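
The new out_channels and is_strided attributes expose block metadata that downstream consumers (presumably feature-extraction code such as detection backbones) can read without inspecting block internals. A short sketch:

```python
import torchvision

model = torchvision.models.mobilenet_v2()
# Each block in .features now reports its output width, and the
# inverted-residual blocks also report whether they downsample.
for i, block in enumerate(model.features):
    print(i, type(block).__name__,
          getattr(block, 'out_channels', None),
          getattr(block, 'is_strided', None))
```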
