Skip to content

2499 adds 2d/3d support of patchembedding, ViT, and unetr model #2698

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Aug 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion monai/networks/blocks/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(
super().__init__()

if not (0 <= dropout_rate <= 1):
raise AssertionError("dropout_rate should be between 0 and 1.")
raise ValueError("dropout_rate should be between 0 and 1.")

self.linear1 = nn.Linear(hidden_size, mlp_dim)
self.linear2 = nn.Linear(mlp_dim, hidden_size)
Expand Down
73 changes: 39 additions & 34 deletions monai/networks/blocks/patchembedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,31 +11,42 @@


import math
from typing import Tuple, Union
from typing import Sequence, Union

import numpy as np
import torch
import torch.nn as nn

from monai.utils import optional_import
from monai.networks.layers import Conv
from monai.utils import ensure_tuple_rep, optional_import
from monai.utils.module import look_up_option

Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
SUPPORTED_EMBEDDING_TYPES = {"conv", "perceptron"}


class PatchEmbeddingBlock(nn.Module):
"""
A patch embedding block, based on: "Dosovitskiy et al.,
An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"

Example::

>>> from monai.networks.blocks import PatchEmbeddingBlock
>>> PatchEmbeddingBlock(in_channels=4, img_size=32, patch_size=8, hidden_size=32, num_heads=4, pos_embed="conv")

"""

def __init__(
self,
in_channels: int,
img_size: Tuple[int, int, int],
patch_size: Tuple[int, int, int],
img_size: Union[Sequence[int], int],
patch_size: Union[Sequence[int], int],
hidden_size: int,
num_heads: int,
pos_embed: str,
dropout_rate: float = 0.0,
spatial_dims: int = 3,
) -> None:
"""
Args:
Expand All @@ -46,47 +57,44 @@ def __init__(
num_heads: number of attention heads.
pos_embed: position embedding layer type.
dropout_rate: faction of the input units to drop.
spatial_dims: number of spatial dimensions.


"""

super().__init__()
super(PatchEmbeddingBlock, self).__init__()

if not (0 <= dropout_rate <= 1):
raise AssertionError("dropout_rate should be between 0 and 1.")
raise ValueError("dropout_rate should be between 0 and 1.")

if hidden_size % num_heads != 0:
raise AssertionError("hidden size should be divisible by num_heads.")
raise ValueError("hidden size should be divisible by num_heads.")

self.pos_embed = look_up_option(pos_embed, SUPPORTED_EMBEDDING_TYPES)

img_size = ensure_tuple_rep(img_size, spatial_dims)
patch_size = ensure_tuple_rep(patch_size, spatial_dims)
for m, p in zip(img_size, patch_size):
if m < p:
raise AssertionError("patch_size should be smaller than img_size.")
raise ValueError("patch_size should be smaller than img_size.")
if self.pos_embed == "perceptron" and m % p != 0:
raise ValueError("patch_size should be divisible by img_size for perceptron.")
self.n_patches = np.prod([im_d // p_d for im_d, p_d in zip(img_size, patch_size)])
self.patch_dim = in_channels * np.prod(patch_size)

if pos_embed not in ["conv", "perceptron"]:
raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.")

if pos_embed == "perceptron":
if img_size[0] % patch_size[0] != 0:
raise AssertionError("img_size should be divisible by patch_size for perceptron patch embedding.")

self.n_patches = (
(img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1]) * (img_size[2] // patch_size[2])
)
self.patch_dim = in_channels * patch_size[0] * patch_size[1] * patch_size[2]

self.pos_embed = pos_embed
self.patch_embeddings: Union[nn.Conv3d, nn.Sequential]
self.patch_embeddings: nn.Module
if self.pos_embed == "conv":
self.patch_embeddings = nn.Conv3d(
self.patch_embeddings = Conv[Conv.CONV, spatial_dims](
in_channels=in_channels, out_channels=hidden_size, kernel_size=patch_size, stride=patch_size
)
elif self.pos_embed == "perceptron":
# for 3d: "b c (h p1) (w p2) (d p3)-> b (h w d) (p1 p2 p3 c)"
chars = (("h", "p1"), ("w", "p2"), ("d", "p3"))[:spatial_dims]
from_chars = "b c " + " ".join(f"({k} {v})" for k, v in chars)
to_chars = f"b ({' '.join([c[0] for c in chars])}) ({' '.join([c[1] for c in chars])} c)"
axes_len = {f"p{i+1}": p for i, p in enumerate(patch_size)}
self.patch_embeddings = nn.Sequential(
Rearrange(
"b c (h p1) (w p2) (d p3)-> b (h w d) (p1 p2 p3 c)",
p1=patch_size[0],
p2=patch_size[1],
p3=patch_size[2],
),
Rearrange(f"{from_chars} -> {to_chars}", **axes_len),
nn.Linear(self.patch_dim, hidden_size),
)
self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, hidden_size))
Expand Down Expand Up @@ -121,12 +129,9 @@ def norm_cdf(x):
return tensor

def forward(self, x):
x = self.patch_embeddings(x)
if self.pos_embed == "conv":
x = self.patch_embeddings(x)
x = x.flatten(2)
x = x.transpose(-1, -2)
elif self.pos_embed == "perceptron":
x = self.patch_embeddings(x)
x = x.flatten(2).transpose(-1, -2)
embeddings = x + self.position_embeddings
embeddings = self.dropout(embeddings)
return embeddings
16 changes: 6 additions & 10 deletions monai/networks/blocks/selfattention.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from monai.utils import optional_import

einops, has_einops = optional_import("einops")
einops, _ = optional_import("einops")


class SABlock(nn.Module):
Expand All @@ -37,13 +37,13 @@ def __init__(

"""

super().__init__()
super(SABlock, self).__init__()

if not (0 <= dropout_rate <= 1):
raise AssertionError("dropout_rate should be between 0 and 1.")
raise ValueError("dropout_rate should be between 0 and 1.")

if hidden_size % num_heads != 0:
raise AssertionError("hidden size should be divisible by num_heads.")
raise ValueError("hidden size should be divisible by num_heads.")

self.num_heads = num_heads
self.out_proj = nn.Linear(hidden_size, hidden_size)
Expand All @@ -52,17 +52,13 @@ def __init__(
self.drop_weights = nn.Dropout(dropout_rate)
self.head_dim = hidden_size // num_heads
self.scale = self.head_dim ** -0.5
if has_einops:
self.rearrange = einops.rearrange
else:
raise ValueError('"Requires einops.')

def forward(self, x):
q, k, v = self.rearrange(self.qkv(x), "b h (qkv l d) -> qkv b l h d", qkv=3, l=self.num_heads)
q, k, v = einops.rearrange(self.qkv(x), "b h (qkv l d) -> qkv b l h d", qkv=3, l=self.num_heads)
att_mat = (torch.einsum("blxd,blyd->blxy", q, k) * self.scale).softmax(dim=-1)
att_mat = self.drop_weights(att_mat)
x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
x = self.rearrange(x, "b h l d -> b l (h d)")
x = einops.rearrange(x, "b h l d -> b l (h d)")
x = self.out_proj(x)
x = self.drop_output(x)
return x
4 changes: 2 additions & 2 deletions monai/networks/blocks/transformerblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ def __init__(
super().__init__()

if not (0 <= dropout_rate <= 1):
raise AssertionError("dropout_rate should be between 0 and 1.")
raise ValueError("dropout_rate should be between 0 and 1.")

if hidden_size % num_heads != 0:
raise AssertionError("hidden size should be divisible by num_heads.")
raise ValueError("hidden_size should be divisible by num_heads.")

self.mlp = MLPBlock(hidden_size, mlp_dim, dropout_rate)
self.norm1 = nn.LayerNorm(hidden_size)
Expand Down
11 changes: 4 additions & 7 deletions monai/networks/blocks/unetr_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,8 @@ def __init__(
self,
spatial_dims: int,
in_channels: int,
out_channels: int, # type: ignore
out_channels: int,
kernel_size: Union[Sequence[int], int],
stride: Union[Sequence[int], int],
upsample_kernel_size: Union[Sequence[int], int],
norm_name: Union[Tuple, str],
res_block: bool = False,
Expand All @@ -41,7 +40,6 @@ def __init__(
in_channels: number of input channels.
out_channels: number of output channels.
kernel_size: convolution kernel size.
stride: convolution stride.
upsample_kernel_size: convolution kernel size for transposed convolution layers.
norm_name: feature normalization type and arguments.
res_block: bool argument to determine if residual block is used.
Expand Down Expand Up @@ -148,7 +146,7 @@ def __init__(
is_transposed=True,
),
UnetResBlock(
spatial_dims=3,
spatial_dims=spatial_dims,
in_channels=out_channels,
out_channels=out_channels,
kernel_size=kernel_size,
Expand All @@ -173,7 +171,7 @@ def __init__(
is_transposed=True,
),
UnetBasicBlock(
spatial_dims=3,
spatial_dims=spatial_dims,
in_channels=out_channels,
out_channels=out_channels,
kernel_size=kernel_size,
Expand Down Expand Up @@ -257,5 +255,4 @@ def __init__(
)

def forward(self, inp):
out = self.layer(inp)
return out
return self.layer(inp)
Loading