
Commit 8a83cf2

kazhang and datumbox authored
Allow custom activation in SqueezeExcitation of EfficientNet (#4448)
* allow custom activation in SqueezeExcitation
* use ReLU as the default activation
* make scale activation parameterizable

Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent 2e0949e commit 8a83cf2

File tree

1 file changed: +12 −4 lines changed


torchvision/models/efficientnet.py

Lines changed: 12 additions & 4 deletions
@@ -32,17 +32,25 @@
 
 
 class SqueezeExcitation(nn.Module):
-    def __init__(self, input_channels: int, squeeze_channels: int):
+    def __init__(
+        self,
+        input_channels: int,
+        squeeze_channels: int,
+        activation: Callable[..., nn.Module] = nn.ReLU,
+        scale_activation: Callable[..., nn.Module] = nn.Sigmoid,
+    ) -> None:
         super().__init__()
         self.fc1 = nn.Conv2d(input_channels, squeeze_channels, 1)
         self.fc2 = nn.Conv2d(squeeze_channels, input_channels, 1)
+        self.activation = activation()
+        self.scale_activation = scale_activation()
 
     def _scale(self, input: Tensor) -> Tensor:
         scale = F.adaptive_avg_pool2d(input, 1)
         scale = self.fc1(scale)
-        scale = F.silu(scale, inplace=True)
+        scale = self.activation(scale)
         scale = self.fc2(scale)
-        return scale.sigmoid()
+        return self.scale_activation(scale)
 
     def forward(self, input: Tensor) -> Tensor:
         scale = self._scale(input)
@@ -108,7 +116,7 @@ def __init__(self, cnf: MBConvConfig, stochastic_depth_prob: float, norm_layer:
 
         # squeeze and excitation
         squeeze_channels = max(1, cnf.input_channels // 4)
-        layers.append(se_layer(expanded_channels, squeeze_channels))
+        layers.append(se_layer(expanded_channels, squeeze_channels, activation=partial(nn.SiLU, inplace=True)))
 
         # project
         layers.append(ConvBNActivation(expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer,
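For reference, a minimal usage sketch of the new parameters (not part of the commit): it assumes SqueezeExcitation is importable from torchvision.models.efficientnet, as defined in this file, and mirrors how MBConv now passes an inplace SiLU via functools.partial.

from functools import partial

import torch
from torch import nn
from torchvision.models.efficientnet import SqueezeExcitation

# Defaults after this change: ReLU between the two 1x1 convs, Sigmoid on the scale.
se_default = SqueezeExcitation(input_channels=64, squeeze_channels=16)

# EfficientNet-style usage: SiLU passed as a factory so inplace=True is preserved.
se_silu = SqueezeExcitation(
    input_channels=64,
    squeeze_channels=16,
    activation=partial(nn.SiLU, inplace=True),
)

x = torch.randn(1, 64, 32, 32)
out = se_silu(x)  # channel-rescaled tensor with the same shape as x: (1, 64, 32, 32)

Accepting a callable factory (rather than a module instance) keeps the block reusable across layers and lets callers bind constructor arguments such as inplace=True, which is why MBConv uses partial(nn.SiLU, inplace=True) instead of an nn.SiLU() instance.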

0 commit comments