Commit 850f5f3

construct model with stem, block, classifier instances
1 parent 4849619 commit 850f5f3

1 file changed

torchvision/models/regnet.py

Lines changed: 85 additions & 69 deletions
@@ -73,6 +73,20 @@ def __init__(
                         norm_layer=norm_layer, activation_layer=activation_layer)
 
 
+def _make_stem(
+    stem_width: int,
+    norm_layer: Callable[..., nn.Module],
+    activation: Callable[..., nn.Module],
+    stem_type: Callable[..., nn.Module] = SimpleStemIN,
+) -> nn.Module:
+    return stem_type(
+        3,  # width_in
+        stem_width,
+        norm_layer,
+        activation,
+    )
+
+
 class VanillaBlock(nn.Sequential):
     """Vanilla block: [3x3 conv, BN, Relu] x2."""
 
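The stem construction that used to happen inline in RegNet.__init__ is now a free-standing helper. A minimal usage sketch (not part of the commit; it assumes the module-level SimpleStemIN and standard torch.nn layers):

import torch
from torch import nn

# Hypothetical call of the new helper; argument names follow the diff above.
stem = _make_stem(stem_width=32, norm_layer=nn.BatchNorm2d, activation=nn.ReLU)

# SimpleStemIN is a 3x3, stride-2 conv + BN + ReLU, so a 224x224 RGB input
# should come out as a 32-channel map at half resolution.
out = stem(torch.randn(1, 3, 224, 224))
print(out.shape)  # expected: torch.Size([1, 32, 112, 112])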
@@ -201,9 +215,6 @@ def __init__(
         )
         self.activation = activation_layer(inplace=True)
 
-        # The projection and transform happen in parallel,
-        # and activation is not counted with respect to depth
-
     def forward(self, x: Tensor) -> Tensor:
         if self.proj_block:
             x = self.bn(self.proj(x)) + self.f(x)
@@ -288,6 +299,7 @@ def __init__(
         bottleneck_multiplier: float = 1.0,
         use_se: bool = True,
         se_ratio: float = 0.25,
+        **kwargs: Any,
     ) -> None:
         if w_a < 0 or w_0 <= 0 or w_m <= 1 or w_0 % 8 != 0:
             raise ValueError("Invalid RegNet settings")
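The **kwargs: Any added to this constructor looks like it is there so that extra keyword arguments (for example num_classes forwarded by the per-architecture builders through the shared **kwargs path) are accepted and ignored instead of raising a TypeError. A hedged sketch of that pattern, with parameter names and values that are assumptions for illustration only:

# Assumed signature, for illustration; the point is that the unknown keyword
# is swallowed by **kwargs instead of failing.
params = BlockParams(
    depth=16, w_0=48, w_a=27.89, w_m=2.09, group_width=8,
    num_classes=10,  # not a BlockParams field; absorbed by **kwargs
)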
@@ -377,83 +389,79 @@ def _adjust_widths_groups_compatibilty(
         return stage_widths, group_widths_min
 
 
-class RegNet(nn.Module):
-    def __init__(
-        self,
-        block_params: BlockParams,
-        num_classes: int = 1000,
-        stem_width: int = 32,
-        stem_type: Optional[Callable[..., nn.Module]] = None,
-        block_type: Optional[Callable[..., nn.Module]] = None,
-        norm_layer: Optional[Callable[..., nn.Module]] = None,
-        activation: Optional[Callable[..., nn.Module]] = None,
-    ) -> None:
-        super().__init__()
-
-        if stem_type is None:
-            stem_type = SimpleStemIN
-        if norm_layer is None:
-            norm_layer = nn.BatchNorm2d
-        if block_type is None:
-            block_type = ResBottleneckBlock
-        if activation is None:
-            activation = nn.ReLU
-
-        # Ad hoc stem
-        self.stem = stem_type(
-            3,  # width_in
-            stem_width,
-            norm_layer,
-            activation,
+def _make_blocks(
+    stem_width: int,
+    params: BlockParams,
+    norm_layer: Callable[..., nn.Module],
+    activation: Callable[..., nn.Module],
+    block_type: Callable[..., nn.Module] = ResBottleneckBlock,
+) -> Tuple[nn.Sequential, int]:
+    current_width = stem_width
+
+    blocks = []
+    for i, (
+        width_out,
+        stride,
+        depth,
+        group_width,
+        bottleneck_multiplier,
+    ) in enumerate(params.get_expanded_params()):
+        blocks.append(
+            (
+                f"block{i+1}",
+                AnyStage(
+                    current_width,
+                    width_out,
+                    stride,
+                    depth,
+                    block_type,
+                    norm_layer,
+                    activation,
+                    group_width,
+                    bottleneck_multiplier,
+                    params.se_ratio,
+                    stage_index=i + 1,
+                ),
+            )
         )
 
-        current_width = stem_width
+        current_width = width_out
+    return (nn.Sequential(OrderedDict(blocks)), current_width)
 
-        blocks = []
-        for i, (
-            width_out,
-            stride,
-            depth,
-            group_width,
-            bottleneck_multiplier,
-        ) in enumerate(block_params.get_expanded_params()):
-            blocks.append(
-                (
-                    f"block{i+1}",
-                    AnyStage(
-                        current_width,
-                        width_out,
-                        stride,
-                        depth,
-                        block_type,
-                        norm_layer,
-                        activation,
-                        group_width,
-                        bottleneck_multiplier,
-                        block_params.se_ratio,
-                        stage_index=i + 1,
-                    ),
-                )
-            )
 
-            current_width = width_out
+class Classifier(nn.Module):
+    def __init__(self, in_channels: int, num_classes: int = 1000) -> None:
+        super().__init__()
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.fc = nn.Linear(in_features=in_channels, out_features=num_classes)
 
-        self.trunk_output = nn.Sequential(OrderedDict(blocks))
+    def forward(self, x: Tensor) -> Tensor:
+        x = self.avgpool(x)
+        x = x.flatten(start_dim=1)
+        x = self.fc(x)
+        return x
 
-        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
-        self.fc = nn.Linear(in_features=current_width, out_features=num_classes)
+
+class RegNet(nn.Module):
+    def __init__(
+        self,
+        stem: nn.Module,
+        blocks: nn.Module,
+        classifier: nn.Module,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__()
+        self.stem = stem
+        self.blocks = blocks
+        self.classifier = classifier
 
         # Init weights and good to go
         self.reset_parameters()
 
     def forward(self, x: Tensor) -> Tensor:
         x = self.stem(x)
-        x = self.trunk_output(x)
-
-        x = self.avgpool(x)
-        x = x.flatten(start_dim=1)
-        x = self.fc(x)
-
+        x = self.blocks(x)
+        x = self.classifier(x)
         return x
 
     def reset_parameters(self) -> None:
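With the constructor reduced to wiring up three injected modules, the trunk and the head become independently swappable. As a sketch of that flexibility (an assumption about intended use, not code from the commit), the new pieces can be reused as a headless feature extractor by injecting nn.Identity() in place of the Classifier:

from torch import nn

def build_backbone(block_params: BlockParams) -> nn.Module:
    # Hypothetical helper (not in the commit): build the stem and trunk with the
    # new factories, then skip the classification head entirely.
    norm_layer, activation = nn.BatchNorm2d, nn.ReLU
    stem = _make_stem(32, norm_layer=norm_layer, activation=activation)
    blocks, _ = _make_blocks(32, params=block_params, norm_layer=norm_layer, activation=activation)
    return RegNet(stem, blocks, nn.Identity())  # forward() returns the trunk feature map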
@@ -472,7 +480,15 @@ def reset_parameters(self) -> None:
 
 
 def _regnet(arch: str, block_params: BlockParams, pretrained: bool, progress: bool, **kwargs: Any) -> RegNet:
-    model = RegNet(block_params, norm_layer=partial(nn.BatchNorm2d, eps=1e-05, momentum=0.1), **kwargs)
+    norm_layer = kwargs["norm_layer"] if "norm_layer" in kwargs else partial(nn.BatchNorm2d, eps=1e-05, momentum=0.1)
+    activation = kwargs["activation"] if "activation" in kwargs else nn.ReLU
+    num_classes = kwargs["num_classes"] if "num_classes" in kwargs else 1000
+
+    stem_width = 32
+    stem = _make_stem(stem_width, norm_layer=norm_layer, activation=activation)
+    blocks, out_channels = _make_blocks(stem_width, params=block_params, norm_layer=norm_layer, activation=activation)
+    classifier = Classifier(out_channels, num_classes)
+    model = RegNet(stem, blocks, classifier)
     if pretrained:
         if arch not in model_urls:
             raise ValueError(f"No checkpoint is available for model type {arch}")
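Because the builder now pulls norm_layer, activation and num_classes out of **kwargs, those should remain overridable from the public entry points. A hedged usage sketch, assuming the per-architecture wrappers in this file (e.g. regnet_y_400mf) keep forwarding **kwargs to _regnet:

from functools import partial
from torch import nn
from torchvision.models.regnet import regnet_y_400mf  # assumed wrapper; forwards **kwargs to _regnet

model = regnet_y_400mf(
    pretrained=False,
    num_classes=10,                                # read via kwargs["num_classes"]
    norm_layer=partial(nn.BatchNorm2d, eps=1e-3),  # replaces the default BatchNorm settings
    activation=nn.SiLU,                            # replaces the default nn.ReLU
)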
