Commit 37a9ee5

Add EfficientNet Architecture in TorchVision (#4293)
* Adding code skeleton
* Adding MBConvConfig.
* Extend SqueezeExcitation to support custom min_value and activation.
* Implement MBConv.
* Replace stochastic_depth with operator.
* Adding the rest of the EfficientNet implementation
* Update torchvision/models/efficientnet.py
* Replacing 1st activation of SE with SiLU.
* Adding efficientnet_b3.
* Replace mobilenetv3 assets with custom.
* Switch to standard sigmoid and reconfiguring BN.
* Reconfiguration of efficientnet.
* Add repr
* Add weights.
* Update weights.
* Adding B5-B7 weights.
* Update docs and hubconf.
* Fix doc link.
* Fix typo on comment.
1 parent d004d77 commit 37a9ee5

16 files changed (+441, -7 lines changed)

docs/source/models.rst

Lines changed: 42 additions & 1 deletion
@@ -27,6 +27,7 @@ architectures for image classification:
 - `ResNeXt`_
 - `Wide ResNet`_
 - `MNASNet`_
+- `EfficientNet`_

 You can construct a model with random weights by calling its constructor:

@@ -47,6 +48,14 @@ You can construct a model with random weights by calling its constructor:
     resnext50_32x4d = models.resnext50_32x4d()
     wide_resnet50_2 = models.wide_resnet50_2()
     mnasnet = models.mnasnet1_0()
+    efficientnet_b0 = models.efficientnet_b0()
+    efficientnet_b1 = models.efficientnet_b1()
+    efficientnet_b2 = models.efficientnet_b2()
+    efficientnet_b3 = models.efficientnet_b3()
+    efficientnet_b4 = models.efficientnet_b4()
+    efficientnet_b5 = models.efficientnet_b5()
+    efficientnet_b6 = models.efficientnet_b6()
+    efficientnet_b7 = models.efficientnet_b7()

 We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`.
 These can be constructed by passing ``pretrained=True``:
@@ -68,6 +77,14 @@ These can be constructed by passing ``pretrained=True``:
     resnext50_32x4d = models.resnext50_32x4d(pretrained=True)
     wide_resnet50_2 = models.wide_resnet50_2(pretrained=True)
     mnasnet = models.mnasnet1_0(pretrained=True)
+    efficientnet_b0 = models.efficientnet_b0(pretrained=True)
+    efficientnet_b1 = models.efficientnet_b1(pretrained=True)
+    efficientnet_b2 = models.efficientnet_b2(pretrained=True)
+    efficientnet_b3 = models.efficientnet_b3(pretrained=True)
+    efficientnet_b4 = models.efficientnet_b4(pretrained=True)
+    efficientnet_b5 = models.efficientnet_b5(pretrained=True)
+    efficientnet_b6 = models.efficientnet_b6(pretrained=True)
+    efficientnet_b7 = models.efficientnet_b7(pretrained=True)

 Instancing a pre-trained model will download its weights to a cache directory.
 This directory can be set using the `TORCH_MODEL_ZOO` environment variable. See
@@ -113,7 +130,10 @@ Unfortunately, the concrete `subset` that was used is lost. For more
 information see `this discussion <https://github.com/pytorch/vision/issues/1439>`_
 or `these experiments <https://github.com/pytorch/vision/pull/1965>`_.

-ImageNet 1-crop error rates (224x224)
+The sizes of the EfficientNet models depend on the variant. For the exact input sizes
+`check here <https://github.com/pytorch/vision/blob/d2bfd639e46e1c5dc3c177f889dc7750c8d137c7/references/classification/train.py#L92-L93>`_
+
+ImageNet 1-crop error rates

 ================================  =============  =============
 Model                             Acc@1          Acc@5
@@ -151,6 +171,14 @@ Wide ResNet-50-2 78.468 94.086
 Wide ResNet-101-2                 78.848         94.284
 MNASNet 1.0                       73.456         91.510
 MNASNet 0.5                       67.734         87.490
+EfficientNet-B0                   77.692         93.532
+EfficientNet-B1                   78.642         94.186
+EfficientNet-B2                   80.608         95.310
+EfficientNet-B3                   82.008         96.054
+EfficientNet-B4                   83.384         96.594
+EfficientNet-B5                   83.444         96.628
+EfficientNet-B6                   84.008         96.916
+EfficientNet-B7                   84.122         96.908
 ================================  =============  =============


@@ -166,6 +194,7 @@ MNASNet 0.5 67.734 87.490
 .. _MobileNetV3: https://arxiv.org/abs/1905.02244
 .. _ResNeXt: https://arxiv.org/abs/1611.05431
 .. _MNASNet: https://arxiv.org/abs/1807.11626
+.. _EfficientNet: https://arxiv.org/abs/1905.11946

 .. currentmodule:: torchvision.models

@@ -267,6 +296,18 @@ MNASNet
 .. autofunction:: mnasnet1_0
 .. autofunction:: mnasnet1_3

+EfficientNet
+------------
+
+.. autofunction:: efficientnet_b0
+.. autofunction:: efficientnet_b1
+.. autofunction:: efficientnet_b2
+.. autofunction:: efficientnet_b3
+.. autofunction:: efficientnet_b4
+.. autofunction:: efficientnet_b5
+.. autofunction:: efficientnet_b6
+.. autofunction:: efficientnet_b7
+
 Quantized Models
 ----------------
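
The accuracy table in docs/source/models.rst sits under a heading that speaks of error rates, while its columns report Acc@1 and Acc@5. Converting between the two is a subtraction from 100. The snippet below is a small sketch of that arithmetic for the EfficientNet rows added in this commit; the helper name `error_rate` is made up for illustration and is not part of torchvision.

```python
# Acc@1 / Acc@5 values copied from the EfficientNet rows of the table.
EFFICIENTNET_ACC = {
    'EfficientNet-B0': (77.692, 93.532),
    'EfficientNet-B1': (78.642, 94.186),
    'EfficientNet-B2': (80.608, 95.310),
    'EfficientNet-B3': (82.008, 96.054),
    'EfficientNet-B4': (83.384, 96.594),
    'EfficientNet-B5': (83.444, 96.628),
    'EfficientNet-B6': (84.008, 96.916),
    'EfficientNet-B7': (84.122, 96.908),
}


def error_rate(accuracy):
    """Convert an accuracy percentage into an error percentage (hypothetical helper)."""
    return round(100.0 - accuracy, 3)


if __name__ == '__main__':
    for name, (acc1, acc5) in EFFICIENTNET_ACC.items():
        print(f'{name}: top-1 error {error_rate(acc1):.3f}, '
              f'top-5 error {error_rate(acc5):.3f}')
```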

hubconf.py

Lines changed: 2 additions & 0 deletions
@@ -15,6 +15,8 @@
 from torchvision.models.mobilenetv3 import mobilenet_v3_large, mobilenet_v3_small
 from torchvision.models.mnasnet import mnasnet0_5, mnasnet0_75, mnasnet1_0, \
     mnasnet1_3
+from torchvision.models.efficientnet import efficientnet_b0, efficientnet_b1, efficientnet_b2, \
+    efficientnet_b3, efficientnet_b4, efficientnet_b5, efficientnet_b6, efficientnet_b7

 # segmentation
 from torchvision.models.segmentation import fcn_resnet50, fcn_resnet101, \

references/classification/README.md

Lines changed: 6 additions & 0 deletions
@@ -68,6 +68,12 @@ Then we averaged the parameters of the last 3 checkpoints that improved the Acc@
 and [#3354](https://github.com/pytorch/vision/pull/3354) for details.


+### EfficientNet
+
+The weights of the B0-B4 variants are ported from Ross Wightman's [timm repo](https://github.com/rwightman/pytorch-image-models/blob/01cb46a9a50e3ba4be167965b5764e9702f09b30/timm/models/efficientnet.py#L95-L108).
+
+The weights of the B5-B7 variants are ported from Luke Melas' [EfficientNet-PyTorch repo](https://github.com/lukemelas/EfficientNet-PyTorch/blob/1039e009545d9329ea026c9f7541341439712b96/efficientnet_pytorch/utils.py#L562-L564).
+
 ## Mixed precision training
 Automatic Mixed Precision (AMP) training on GPU for Pytorch can be enabled with the [NVIDIA Apex extension](https://github.com/NVIDIA/apex).

references/classification/presets.py

Lines changed: 4 additions & 2 deletions
@@ -1,4 +1,5 @@
 from torchvision.transforms import autoaugment, transforms
+from torchvision.transforms.functional import InterpolationMode


 class ClassificationPresetTrain:
@@ -24,10 +25,11 @@ def __call__(self, img):


 class ClassificationPresetEval:
-    def __init__(self, crop_size, resize_size=256, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
+    def __init__(self, crop_size, resize_size=256, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225),
+                 interpolation=InterpolationMode.BILINEAR):

         self.transforms = transforms.Compose([
-            transforms.Resize(resize_size),
+            transforms.Resize(resize_size, interpolation=interpolation),
             transforms.CenterCrop(crop_size),
             transforms.ToTensor(),
             transforms.Normalize(mean=mean, std=std),
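
To see what the `resize_size` and `crop_size` arguments of `ClassificationPresetEval` do to an image's dimensions, the geometry of a shorter-edge resize followed by a center crop can be worked out without any imaging library. This is an illustrative approximation under stated assumptions, not torchvision's implementation: both function names are hypothetical, the longer edge is truncated to an integer, and the crop offset uses floor division, so results may differ by a pixel from the real transforms. The interpolation mode changes pixel values only, not these dimensions.

```python
def resized_dims(width, height, resize_size):
    """Scale so the shorter edge equals resize_size, keeping the aspect ratio.

    Hypothetical helper: the longer edge is truncated to an int.
    """
    if width <= height:
        return resize_size, int(resize_size * height / width)
    return int(resize_size * width / height), resize_size


def center_crop_box(width, height, crop_size):
    """Return (left, top, right, bottom) of a centered square crop.

    Hypothetical helper: offsets use floor division.
    """
    left = (width - crop_size) // 2
    top = (height - crop_size) // 2
    return left, top, left + crop_size, top + crop_size


if __name__ == '__main__':
    # 256 and 224 are the default resize and crop sizes in train.py.
    w, h = resized_dims(600, 400, 256)
    print((w, h), center_crop_box(w, h, 224))
```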

references/classification/train.py

Lines changed: 15 additions & 2 deletions
@@ -6,6 +6,7 @@
 import torch.utils.data
 from torch import nn
 import torchvision
+from torchvision.transforms.functional import InterpolationMode

 import presets
 import utils
@@ -82,7 +83,18 @@ def _get_cache_path(filepath):
 def load_data(traindir, valdir, args):
     # Data loading code
     print("Loading data")
-    resize_size, crop_size = (342, 299) if args.model == 'inception_v3' else (256, 224)
+    resize_size, crop_size = 256, 224
+    interpolation = InterpolationMode.BILINEAR
+    if args.model == 'inception_v3':
+        resize_size, crop_size = 342, 299
+    elif args.model.startswith('efficientnet_'):
+        sizes = {
+            'b0': (256, 224), 'b1': (256, 240), 'b2': (288, 288), 'b3': (320, 300),
+            'b4': (384, 380), 'b5': (456, 456), 'b6': (528, 528), 'b7': (600, 600),
+        }
+        e_type = args.model.replace('efficientnet_', '')
+        resize_size, crop_size = sizes[e_type]
+        interpolation = InterpolationMode.BICUBIC

     print("Loading training data")
     st = time.time()
@@ -113,7 +125,8 @@ def load_data(traindir, valdir, args):
     else:
         dataset_test = torchvision.datasets.ImageFolder(
             valdir,
-            presets.ClassificationPresetEval(crop_size=crop_size, resize_size=resize_size))
+            presets.ClassificationPresetEval(crop_size=crop_size, resize_size=resize_size,
+                                             interpolation=interpolation))
         if args.cache_dataset:
             print("Saving dataset_test to {}".format(cache_path))
             utils.mkdir(os.path.dirname(cache_path))
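
The per-variant lookup added to `load_data` can be exercised on its own. The following is a minimal sketch rather than the script's code: the helper name `eval_preprocessing` is hypothetical, and interpolation modes are represented as plain strings instead of `InterpolationMode` members so the snippet runs without torchvision installed. The size table and the bilinear/bicubic choice are copied from the train.py hunk.

```python
# (resize_size, crop_size) per EfficientNet variant, copied from the diff.
EFFICIENTNET_SIZES = {
    'b0': (256, 224), 'b1': (256, 240), 'b2': (288, 288), 'b3': (320, 300),
    'b4': (384, 380), 'b5': (456, 456), 'b6': (528, 528), 'b7': (600, 600),
}


def eval_preprocessing(model_name):
    """Return (resize_size, crop_size, interpolation) for a model name.

    Hypothetical standalone helper mirroring the branch added to load_data();
    interpolation is a plain string here, not an InterpolationMode member.
    """
    resize_size, crop_size = 256, 224
    interpolation = 'bilinear'
    if model_name == 'inception_v3':
        resize_size, crop_size = 342, 299
    elif model_name.startswith('efficientnet_'):
        e_type = model_name.replace('efficientnet_', '')
        resize_size, crop_size = EFFICIENTNET_SIZES[e_type]
        interpolation = 'bicubic'
    return resize_size, crop_size, interpolation


if __name__ == '__main__':
    for name in ('resnet50', 'inception_v3', 'efficientnet_b3'):
        print(name, eval_preprocessing(name))
```

As in the diff, a name with an unknown suffix such as `efficientnet_b8` raises a `KeyError` from the dictionary lookup instead of silently falling back to the defaults.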
5 binary files not shown.
