|
3 | 3 | from functools import partial
|
4 | 4 | from torch import nn, Tensor
|
5 | 5 | from torch.nn import functional as F
|
6 |
| -from typing import Any, Callable, List, Optional |
| 6 | +from typing import Any, Callable, List, Optional, Sequence |
7 | 7 |
|
8 | 8 | from torchvision.models.utils import load_state_dict_from_url
|
9 | 9 | from torchvision.models.mobilenetv2 import _make_divisible, ConvBNActivation
|
|
14 | 14 |
|
15 | 15 | # TODO: add pretrained
|
16 | 16 | model_urls = {
|
17 |
| - "mobilenet_v3_large_1_0": None, |
18 |
| - "mobilenet_v3_small_1_0": None, |
| 17 | + "mobilenet_v3_large": None, |
| 18 | + "mobilenet_v3_small": None, |
19 | 19 | }
|
20 | 20 |
|
21 | 21 |
|
@@ -67,7 +67,8 @@ class InvertedResidual(nn.Module):
|
67 | 67 |
|
68 | 68 | def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Module]):
|
69 | 69 | super().__init__()
|
70 |
| - assert cnf.stride in [1, 2] |
| 70 | + if not (1 <= cnf.stride <= 2): |
| 71 | + raise ValueError('illegal stride value') |
71 | 72 |
|
72 | 73 | self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.output_channels
|
73 | 74 |
|
@@ -123,6 +124,9 @@ def __init__(
|
123 | 124 |
|
124 | 125 | if not inverted_residual_setting:
|
125 | 126 | raise ValueError("The inverted_residual_setting should not be empty")
|
| 127 | + elif not (isinstance(inverted_residual_setting, Sequence) and |
| 128 | + all([isinstance(s, InvertedResidualConfig) for s in inverted_residual_setting])): |
| 129 | + raise TypeError("The inverted_residual_setting should be List[InvertedResidualConfig]") |
126 | 130 |
|
127 | 131 | if block is None:
|
128 | 132 | block = InvertedResidual
|
@@ -229,8 +233,7 @@ def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, **kwargs
|
229 | 233 | ]
|
230 | 234 | last_channel = adjust_channels(1280)
|
231 | 235 |
|
232 |
| - return _mobilenet_v3("mobilenet_v3_large_1_0", inverted_residual_setting, last_channel, pretrained, progress, |
233 |
| - **kwargs) |
| 236 | + return _mobilenet_v3("mobilenet_v3_large", inverted_residual_setting, last_channel, pretrained, progress, **kwargs) |
234 | 237 |
|
235 | 238 |
|
236 | 239 | def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3:
|
@@ -261,5 +264,4 @@ def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, **kwargs
|
261 | 264 | ]
|
262 | 265 | last_channel = adjust_channels(1024)
|
263 | 266 |
|
264 |
| - return _mobilenet_v3("mobilenet_v3_small_1_0", inverted_residual_setting, last_channel, pretrained, progress, |
265 |
| - **kwargs) |
| 267 | + return _mobilenet_v3("mobilenet_v3_small", inverted_residual_setting, last_channel, pretrained, progress, **kwargs) |
0 commit comments