Skip to content

Commit 9a758a8

Browse files
committed
Code review nits.
1 parent 5030435 commit 9a758a8

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

torchvision/models/mobilenetv3.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from functools import partial
44
from torch import nn, Tensor
55
from torch.nn import functional as F
6-
from typing import Any, Callable, List, Optional
6+
from typing import Any, Callable, List, Optional, Sequence
77

88
from torchvision.models.utils import load_state_dict_from_url
99
from torchvision.models.mobilenetv2 import _make_divisible, ConvBNActivation
@@ -14,8 +14,8 @@
1414

1515
# TODO: add pretrained
1616
model_urls = {
17-
"mobilenet_v3_large_1_0": None,
18-
"mobilenet_v3_small_1_0": None,
17+
"mobilenet_v3_large": None,
18+
"mobilenet_v3_small": None,
1919
}
2020

2121

@@ -67,7 +67,8 @@ class InvertedResidual(nn.Module):
6767

6868
def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Module]):
6969
super().__init__()
70-
assert cnf.stride in [1, 2]
70+
if not (1 <= cnf.stride <= 2):
71+
raise ValueError('illegal stride value')
7172

7273
self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.output_channels
7374

@@ -123,6 +124,9 @@ def __init__(
123124

124125
if not inverted_residual_setting:
125126
raise ValueError("The inverted_residual_setting should not be empty")
127+
elif not (isinstance(inverted_residual_setting, Sequence) and
128+
all([isinstance(s, InvertedResidualConfig) for s in inverted_residual_setting])):
129+
raise TypeError("The inverted_residual_setting should be List[InvertedResidualConfig]")
126130

127131
if block is None:
128132
block = InvertedResidual
@@ -229,8 +233,7 @@ def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, **kwargs
229233
]
230234
last_channel = adjust_channels(1280)
231235

232-
return _mobilenet_v3("mobilenet_v3_large_1_0", inverted_residual_setting, last_channel, pretrained, progress,
233-
**kwargs)
236+
return _mobilenet_v3("mobilenet_v3_large", inverted_residual_setting, last_channel, pretrained, progress, **kwargs)
234237

235238

236239
def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3:
@@ -261,5 +264,4 @@ def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, **kwargs
261264
]
262265
last_channel = adjust_channels(1024)
263266

264-
return _mobilenet_v3("mobilenet_v3_small_1_0", inverted_residual_setting, last_channel, pretrained, progress,
265-
**kwargs)
267+
return _mobilenet_v3("mobilenet_v3_small", inverted_residual_setting, last_channel, pretrained, progress, **kwargs)

0 commit comments

Comments
 (0)