19
19
}
20
20
21
21
22
- def drop_connect (x : Tensor , rate : float ):
22
+ def drop_connect (x : Tensor , rate : float ) -> Tensor :
23
23
keep = torch .rand (size = (x .size (0 ), ), dtype = x .dtype , device = x .device ) > rate
24
24
keep = keep [(None , ) * (x .ndim - 1 )].T
25
25
return (x / (1.0 - rate )) * keep
@@ -29,7 +29,7 @@ class MBConvConfig:
29
29
def __init__ (self ,
30
30
kernel : int , stride : int , dilation : int ,
31
31
input_channels : int , out_channels : int , expand_ratio : float ,
32
- width_mult : float ):
32
+ width_mult : float ) -> None :
33
33
self .kernel = kernel
34
34
self .stride = stride
35
35
self .dilation = dilation
@@ -38,13 +38,13 @@ def __init__(self,
38
38
self .expanded_channels = self .adjust_channels (input_channels , expand_ratio * width_mult )
39
39
40
40
@staticmethod
41
- def adjust_channels (channels : int , width_mult : float , min_value : Optional [int ] = None ):
41
+ def adjust_channels (channels : int , width_mult : float , min_value : Optional [int ] = None ) -> int :
42
42
return _make_divisible (channels * width_mult , 8 , min_value )
43
43
44
44
45
45
class MBConv (nn .Module ):
46
46
def __init__ (self , cnf : MBConvConfig , norm_layer : Callable [..., nn .Module ],
47
- se_layer : Callable [..., nn .Module ] = SqueezeExcitation ):
47
+ se_layer : Callable [..., nn .Module ] = SqueezeExcitation ) -> None :
48
48
super ().__init__ ()
49
49
if not (1 <= cnf .stride <= 2 ):
50
50
raise ValueError ('illegal stride value' )
0 commit comments