Skip to content

Commit 9055cf1

Browse files
jdsgomescyyever
authored andcommitted
Multiweight DenseNet prototype models (pytorch#4680)
* Densenet121 added * All densenet prototypes added * fixing flake8 errors * fixing argument type
1 parent 2887360 commit 9055cf1

File tree

2 files changed

+157
-0
lines changed

2 files changed

+157
-0
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from .alexnet import *
22
from .resnet import *
3+
from .densenet import *
34
from .vgg import *
45
from . import detection
56
from . import quantization
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
import re
2+
import warnings
3+
from functools import partial
4+
from typing import Any, Optional, Tuple
5+
6+
import torch.nn as nn
7+
8+
from ...models.densenet import DenseNet
9+
from ..transforms.presets import ImageNetEval
10+
from ._api import Weights, WeightEntry
11+
from ._meta import _IMAGENET_CATEGORIES
12+
13+
14+
__all__ = [
15+
"DenseNet",
16+
"DenseNet121Weights",
17+
"DenseNet161Weights",
18+
"DenseNet169Weights",
19+
"DenseNet201Weights",
20+
"densenet121",
21+
"densenet161",
22+
"densenet169",
23+
"densenet201",
24+
]
25+
26+
27+
def _load_state_dict(model: nn.Module, weights: Weights, progress: bool) -> None:
28+
# '.'s are no longer allowed in module names, but previous _DenseLayer
29+
# has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
30+
# They are also in the checkpoints in model_urls. This pattern is used
31+
# to find such keys.
32+
pattern = re.compile(
33+
r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
34+
)
35+
36+
state_dict = weights.state_dict(progress=progress)
37+
for key in list(state_dict.keys()):
38+
res = pattern.match(key)
39+
if res:
40+
new_key = res.group(1) + res.group(2)
41+
state_dict[new_key] = state_dict[key]
42+
del state_dict[key]
43+
model.load_state_dict(state_dict)
44+
45+
46+
def _densenet(
47+
growth_rate: int,
48+
block_config: Tuple[int, int, int, int],
49+
num_init_features: int,
50+
weights: Optional[Weights],
51+
progress: bool,
52+
**kwargs: Any,
53+
) -> DenseNet:
54+
if weights is not None:
55+
kwargs["num_classes"] = len(weights.meta["categories"])
56+
57+
model = DenseNet(growth_rate, block_config, num_init_features, **kwargs)
58+
59+
if weights is not None:
60+
_load_state_dict(model=model, weights=weights, progress=progress)
61+
62+
return model
63+
64+
65+
_common_meta = {
66+
"size": (224, 224),
67+
"categories": _IMAGENET_CATEGORIES,
68+
}
69+
70+
71+
class DenseNet121Weights(Weights):
72+
ImageNet1K_RefV1 = WeightEntry(
73+
url="https://download.pytorch.org/models/densenet121-a639ec97.pth",
74+
transforms=partial(ImageNetEval, crop_size=224),
75+
meta={
76+
**_common_meta,
77+
"recipe": "",
78+
"acc@1": 74.434,
79+
"acc@5": 91.972,
80+
},
81+
)
82+
83+
84+
class DenseNet161Weights(Weights):
85+
ImageNet1K_RefV1 = WeightEntry(
86+
url="https://download.pytorch.org/models/densenet161-8d451a50.pth",
87+
transforms=partial(ImageNetEval, crop_size=224),
88+
meta={
89+
**_common_meta,
90+
"recipe": "",
91+
"acc@1": 77.138,
92+
"acc@5": 93.560,
93+
},
94+
)
95+
96+
97+
class DenseNet169Weights(Weights):
98+
ImageNet1K_RefV1 = WeightEntry(
99+
url="https://download.pytorch.org/models/densenet169-b2777c0a.pth",
100+
transforms=partial(ImageNetEval, crop_size=224),
101+
meta={
102+
**_common_meta,
103+
"recipe": "",
104+
"acc@1": 75.600,
105+
"acc@5": 92.806,
106+
},
107+
)
108+
109+
110+
class DenseNet201Weights(Weights):
111+
ImageNet1K_RefV1 = WeightEntry(
112+
url="https://download.pytorch.org/models/densenet201-c1103571.pth",
113+
transforms=partial(ImageNetEval, crop_size=224),
114+
meta={
115+
**_common_meta,
116+
"recipe": "",
117+
"acc@1": 76.896,
118+
"acc@5": 93.370,
119+
},
120+
)
121+
122+
123+
def densenet121(weights: Optional[DenseNet121Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
124+
if "pretrained" in kwargs:
125+
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
126+
weights = DenseNet121Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
127+
weights = DenseNet121Weights.verify(weights)
128+
129+
return _densenet(32, (6, 12, 24, 16), 64, weights, progress, **kwargs)
130+
131+
132+
def densenet161(weights: Optional[DenseNet161Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
133+
if "pretrained" in kwargs:
134+
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
135+
weights = DenseNet161Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
136+
weights = DenseNet161Weights.verify(weights)
137+
138+
return _densenet(48, (6, 12, 36, 24), 96, weights, progress, **kwargs)
139+
140+
141+
def densenet169(weights: Optional[DenseNet169Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
142+
if "pretrained" in kwargs:
143+
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
144+
weights = DenseNet169Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
145+
weights = DenseNet169Weights.verify(weights)
146+
147+
return _densenet(32, (6, 12, 32, 32), 64, weights, progress, **kwargs)
148+
149+
150+
def densenet201(weights: Optional[DenseNet201Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
151+
if "pretrained" in kwargs:
152+
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
153+
weights = DenseNet201Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
154+
weights = DenseNet201Weights.verify(weights)
155+
156+
return _densenet(32, (6, 12, 48, 32), 64, weights, progress, **kwargs)

0 commit comments

Comments
 (0)