@@ -73,6 +73,20 @@ def __init__(
             norm_layer=norm_layer, activation_layer=activation_layer)
 
 
+def _make_stem(
+    stem_width: int,
+    norm_layer: Callable[..., nn.Module],
+    activation: Callable[..., nn.Module],
+    stem_type: Callable[..., nn.Module] = SimpleStemIN,
+) -> nn.Module:
+    return stem_type(
+        3,  # width_in
+        stem_width,
+        norm_layer,
+        activation,
+    )
+
+
 class VanillaBlock(nn.Sequential):
     """Vanilla block: [3x3 conv, BN, Relu] x2."""
 
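(For reference: a minimal smoke test for the extracted `_make_stem` helper, assuming the stride-2 `SimpleStemIN` defined earlier in this file and the usual `torch` imports; the shapes below are an illustration, not part of the diff.)

    import torch
    from torch import nn

    # SimpleStemIN is a stride-2 3x3 conv + BN + activation, so a 224x224
    # input comes out at half resolution with `stem_width` channels.
    stem = _make_stem(stem_width=32, norm_layer=nn.BatchNorm2d, activation=nn.ReLU)
    out = stem(torch.randn(1, 3, 224, 224))
    assert out.shape == (1, 32, 112, 112)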
@@ -201,9 +215,6 @@ def __init__(
         )
         self.activation = activation_layer(inplace=True)
 
-        # The projection and transform happen in parallel,
-        # and activation is not counted with respect to depth
-
     def forward(self, x: Tensor) -> Tensor:
         if self.proj_block:
             x = self.bn(self.proj(x)) + self.f(x)
@@ -288,6 +299,7 @@ def __init__(
         bottleneck_multiplier: float = 1.0,
         use_se: bool = True,
         se_ratio: float = 0.25,
+        **kwargs: Any,
     ) -> None:
         if w_a < 0 or w_0 <= 0 or w_m <= 1 or w_0 % 8 != 0:
             raise ValueError("Invalid RegNet settings")
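(Aside: the new `**kwargs: Any` makes the parameter object tolerant of extra keyword arguments, so a builder can forward its whole kwargs dict without filtering. A standalone sketch of the pattern, using a stand-in class rather than the real `BlockParams`:)

    from typing import Any

    class Params:
        # Same idea as in the diff: unknown keyword arguments are accepted
        # and silently ignored, so callers can pass **kwargs straight through.
        def __init__(self, depth: int, **kwargs: Any) -> None:
            self.depth = depth

    p = Params(depth=13, norm_layer=None, activation=None)  # extras ignored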
@@ -377,83 +389,79 @@ def _adjust_widths_groups_compatibilty(
         return stage_widths, group_widths_min
 
 
-class RegNet(nn.Module):
-    def __init__(
-        self,
-        block_params: BlockParams,
-        num_classes: int = 1000,
-        stem_width: int = 32,
-        stem_type: Optional[Callable[..., nn.Module]] = None,
-        block_type: Optional[Callable[..., nn.Module]] = None,
-        norm_layer: Optional[Callable[..., nn.Module]] = None,
-        activation: Optional[Callable[..., nn.Module]] = None,
-    ) -> None:
-        super().__init__()
-
-        if stem_type is None:
-            stem_type = SimpleStemIN
-        if norm_layer is None:
-            norm_layer = nn.BatchNorm2d
-        if block_type is None:
-            block_type = ResBottleneckBlock
-        if activation is None:
-            activation = nn.ReLU
-
-        # Ad hoc stem
-        self.stem = stem_type(
-            3,  # width_in
-            stem_width,
-            norm_layer,
-            activation,
+def _make_blocks(
+    stem_width: int,
+    params: BlockParams,
+    norm_layer: Callable[..., nn.Module],
+    activation: Callable[..., nn.Module],
+    block_type: Callable[..., nn.Module] = ResBottleneckBlock,
+) -> Tuple[nn.Sequential, int]:
+    current_width = stem_width
+
+    blocks = []
+    for i, (
+        width_out,
+        stride,
+        depth,
+        group_width,
+        bottleneck_multiplier,
+    ) in enumerate(params.get_expanded_params()):
+        blocks.append(
+            (
+                f"block{i+1}",
+                AnyStage(
+                    current_width,
+                    width_out,
+                    stride,
+                    depth,
+                    block_type,
+                    norm_layer,
+                    activation,
+                    group_width,
+                    bottleneck_multiplier,
+                    params.se_ratio,
+                    stage_index=i + 1,
+                ),
+            )
         )
 
-        current_width = stem_width
+        current_width = width_out
+    return (nn.Sequential(OrderedDict(blocks)), current_width)
 
-        blocks = []
-        for i, (
-            width_out,
-            stride,
-            depth,
-            group_width,
-            bottleneck_multiplier,
-        ) in enumerate(block_params.get_expanded_params()):
-            blocks.append(
-                (
-                    f"block{i+1}",
-                    AnyStage(
-                        current_width,
-                        width_out,
-                        stride,
-                        depth,
-                        block_type,
-                        norm_layer,
-                        activation,
-                        group_width,
-                        bottleneck_multiplier,
-                        block_params.se_ratio,
-                        stage_index=i + 1,
-                    ),
-                )
-            )
-
-            current_width = width_out
-
-        self.trunk_output = nn.Sequential(OrderedDict(blocks))
+
+class Classifier(nn.Module):
+    def __init__(self, in_channels: int, num_classes: int = 1000) -> None:
+        super().__init__()
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.fc = nn.Linear(in_features=in_channels, out_features=num_classes)
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = self.avgpool(x)
+        x = x.flatten(start_dim=1)
+        x = self.fc(x)
+        return x
 
-        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
-        self.fc = nn.Linear(in_features=current_width, out_features=num_classes)
+
+class RegNet(nn.Module):
+    def __init__(
+        self,
+        stem: nn.Module,
+        blocks: nn.Module,
+        classifier: nn.Module,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__()
+        self.stem = stem
+        self.blocks = blocks
+        self.classifier = classifier
 
         # Init weights and good to go
         self.reset_parameters()
 
     def forward(self, x: Tensor) -> Tensor:
         x = self.stem(x)
-        x = self.trunk_output(x)
-
-        x = self.avgpool(x)
-        x = x.flatten(start_dim=1)
-        x = self.fc(x)
-
+        x = self.blocks(x)
+        x = self.classifier(x)
         return x
 
     def reset_parameters(self) -> None:
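(Taken together, the old monolithic `RegNet.__init__` is now three composable pieces: a stem, a trunk of stages, and a classification head. A hedged sketch of the wiring, mirroring what `_regnet` does below; `block_params` stands for any already-built `BlockParams` instance:)

    from functools import partial
    from torch import nn

    # Build each piece, then hand the three modules to the now-container-like RegNet.
    norm_layer = partial(nn.BatchNorm2d, eps=1e-05, momentum=0.1)
    stem = _make_stem(32, norm_layer=norm_layer, activation=nn.ReLU)
    blocks, out_channels = _make_blocks(32, params=block_params, norm_layer=norm_layer, activation=nn.ReLU)
    model = RegNet(stem, blocks, Classifier(out_channels, num_classes=1000))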
@@ -472,7 +480,15 @@ def reset_parameters(self) -> None:
 
 
 def _regnet(arch: str, block_params: BlockParams, pretrained: bool, progress: bool, **kwargs: Any) -> RegNet:
-    model = RegNet(block_params, norm_layer=partial(nn.BatchNorm2d, eps=1e-05, momentum=0.1), **kwargs)
+    norm_layer = kwargs["norm_layer"] if "norm_layer" in kwargs else partial(nn.BatchNorm2d, eps=1e-05, momentum=0.1)
+    activation = kwargs["activation"] if "activation" in kwargs else nn.ReLU
+    num_classes = kwargs["num_classes"] if "num_classes" in kwargs else 1000
+
+    stem_width = 32
+    stem = _make_stem(stem_width, norm_layer=norm_layer, activation=activation)
+    blocks, out_channels = _make_blocks(stem_width, params=block_params, norm_layer=norm_layer, activation=activation)
+    classifier = Classifier(out_channels, num_classes)
+    model = RegNet(stem, blocks, classifier)
     if pretrained:
         if arch not in model_urls:
             raise ValueError(f"No checkpoint is available for model type {arch}")
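(Usage is unchanged from the caller's side; per-model overrides now travel through `**kwargs` and are consumed here rather than inside `RegNet`. A hypothetical call, with `block_params` assumed to exist and an invented arch name:)

    # num_classes is read out of kwargs by _regnet and lands in the Classifier head.
    model = _regnet("regnet_sample", block_params, pretrained=False, progress=True, num_classes=10)
    assert model.classifier.fc.out_features == 10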