Skip to content

2812 enhance swish and efficientnet #2813

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Aug 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/source/networks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,11 @@ Nets
.. autoclass:: EfficientNetBN
:members:

`EfficientNetBNFeatures`
~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: EfficientNetBNFeatures
:members:

`SegResNet`
~~~~~~~~~~~
.. autoclass:: SegResNet
Expand Down
37 changes: 32 additions & 5 deletions monai/networks/blocks/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,28 @@

if optional_import("torch.nn.functional", name="mish")[1]:

def monai_mish(x):
return torch.nn.functional.mish(x, inplace=True)
def monai_mish(x, inplace: bool = False):
return torch.nn.functional.mish(x, inplace=inplace)


else:

def monai_mish(x):
def monai_mish(x, inplace: bool = False):
return x * torch.tanh(torch.nn.functional.softplus(x))


if optional_import("torch.nn.functional", name="silu")[1]:

def monai_swish(x, inplace: bool = False):
return torch.nn.functional.silu(x, inplace=inplace)


else:

def monai_swish(x, inplace: bool = False):
return SwishImplementation.apply(x)


class Swish(nn.Module):
r"""Applies the element-wise function:

Expand Down Expand Up @@ -92,6 +104,9 @@ class MemoryEfficientSwish(nn.Module):

Citation: Searching for Activation Functions, Ramachandran et al., 2017, https://arxiv.org/abs/1710.05941.

From Pytorch 1.7.0+, the optimized version of `Swish` named `SiLU` is implemented,
this class will utilize `torch.nn.functional.silu` to do the calculation if meets the version.

Shape:
- Input: :math:`(N, *)` where `*` means, any number of additional
dimensions
Expand All @@ -107,8 +122,13 @@ class MemoryEfficientSwish(nn.Module):
>>> output = m(input)
"""

def __init__(self, inplace: bool = False):
super(MemoryEfficientSwish, self).__init__()
# inplace only works when using torch.nn.functional.silu
self.inplace = inplace

def forward(self, input: torch.Tensor):
return SwishImplementation.apply(input)
return monai_swish(input, self.inplace)


class Mish(nn.Module):
Expand All @@ -119,6 +139,8 @@ class Mish(nn.Module):

Citation: Mish: A Self Regularized Non-Monotonic Activation Function, Diganta Misra, 2019, https://arxiv.org/abs/1908.08681.

From Pytorch 1.9.0+, the optimized version of `Mish` is implemented,
this class will utilize `torch.nn.functional.mish` to do the calculation if meets the version.

Shape:
- Input: :math:`(N, *)` where `*` means, any number of additional
Expand All @@ -135,5 +157,10 @@ class Mish(nn.Module):
>>> output = m(input)
"""

def __init__(self, inplace: bool = False):
super(Mish, self).__init__()
# inplace only works when using torch.nn.functional.mish
self.inplace = inplace

def forward(self, input: torch.Tensor):
return monai_mish(input)
return monai_mish(input, self.inplace)
9 changes: 8 additions & 1 deletion monai/networks/nets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,14 @@
densenet264,
)
from .dynunet import DynUNet, DynUnet, Dynunet, dynunet
from .efficientnet import BlockArgs, EfficientNet, EfficientNetBN, drop_connect, get_efficientnet_image_size
from .efficientnet import (
BlockArgs,
EfficientNet,
EfficientNetBN,
EfficientNetBNFeatures,
drop_connect,
get_efficientnet_image_size,
)
from .fullyconnectednet import FullyConnectedNet, VarFullyConnectedNet
from .generator import Generator
from .highresnet import HighResBlock, HighResNet
Expand Down
9 changes: 5 additions & 4 deletions monai/networks/nets/densenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from monai.networks.layers.factories import Conv, Dropout, Pool
from monai.networks.layers.utils import get_act_layer, get_norm_layer
from monai.utils.module import look_up_option

__all__ = [
"DenseNet",
Expand Down Expand Up @@ -249,7 +250,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x


def _load_state_dict(model, arch, progress):
def _load_state_dict(model: nn.Module, arch: str, progress: bool):
"""
This function is used to load pretrained models.
Adapted from PyTorch Hub 2D version: https://pytorch.org/vision/stable/models.html#id16.
Expand All @@ -260,12 +261,12 @@ def _load_state_dict(model, arch, progress):
"densenet169": "https://download.pytorch.org/models/densenet169-b2777c0a.pth",
"densenet201": "https://download.pytorch.org/models/densenet201-c1103571.pth",
}
if arch in model_urls:
model_url = model_urls[arch]
else:
model_url = look_up_option(arch, model_urls, None)
if model_url is None:
raise ValueError(
"only 'densenet121', 'densenet169' and 'densenet201' are supported to load pretrained weights."
)

pattern = re.compile(
r"^(.*denselayer\d+)(\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
)
Expand Down
Loading