|
| 1 | +import torch |
| 2 | +import torch.fx |
| 3 | +import torch.nn.functional as F |
| 4 | +from torch import nn, Tensor |
| 5 | + |
| 6 | +from ..utils import _log_api_usage_once |
| 7 | + |
| 8 | + |
| 9 | +def drop_block2d( |
| 10 | + input: Tensor, p: float, block_size: int, inplace: bool = False, eps: float = 1e-06, training: bool = True |
| 11 | +) -> Tensor: |
| 12 | + """ |
| 13 | + Implements DropBlock2d from `"DropBlock: A regularization method for convolutional networks" |
| 14 | + <https://arxiv.org/abs/1810.12890>`. |
| 15 | +
|
| 16 | + Args: |
| 17 | + input (Tensor[N, C, H, W]): The input tensor or 4-dimensions with the first one |
| 18 | + being its batch i.e. a batch with ``N`` rows. |
| 19 | + p (float): Probability of an element to be dropped. |
| 20 | + block_size (int): Size of the block to drop. |
| 21 | + inplace (bool): If set to ``True``, will do this operation in-place. Default: ``False``. |
| 22 | + eps (float): A value added to the denominator for numerical stability. Default: 1e-6. |
| 23 | + training (bool): apply dropblock if is ``True``. Default: ``True``. |
| 24 | +
|
| 25 | + Returns: |
| 26 | + Tensor[N, C, H, W]: The randomly zeroed tensor after dropblock. |
| 27 | + """ |
| 28 | + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): |
| 29 | + _log_api_usage_once(drop_block2d) |
| 30 | + if p < 0.0 or p > 1.0: |
| 31 | + raise ValueError(f"drop probability has to be between 0 and 1, but got {p}.") |
| 32 | + if input.ndim != 4: |
| 33 | + raise ValueError(f"input should be 4 dimensional. Got {input.ndim} dimensions.") |
| 34 | + if not training or p == 0.0: |
| 35 | + return input |
| 36 | + |
| 37 | + N, C, H, W = input.size() |
| 38 | + block_size = min(block_size, W, H) |
| 39 | + # compute the gamma of Bernoulli distribution |
| 40 | + gamma = (p * H * W) / ((block_size ** 2) * ((H - block_size + 1) * (W - block_size + 1))) |
| 41 | + noise = torch.empty((N, C, H - block_size + 1, W - block_size + 1), dtype=input.dtype, device=input.device) |
| 42 | + noise.bernoulli_(gamma) |
| 43 | + |
| 44 | + noise = F.pad(noise, [block_size // 2] * 4, value=0) |
| 45 | + noise = F.max_pool2d(noise, stride=(1, 1), kernel_size=(block_size, block_size), padding=block_size // 2) |
| 46 | + noise = 1 - noise |
| 47 | + normalize_scale = noise.numel() / (eps + noise.sum()) |
| 48 | + if inplace: |
| 49 | + input.mul_(noise).mul_(normalize_scale) |
| 50 | + else: |
| 51 | + input = input * noise * normalize_scale |
| 52 | + return input |
| 53 | + |
| 54 | + |
| 55 | +def drop_block3d( |
| 56 | + input: Tensor, p: float, block_size: int, inplace: bool = False, eps: float = 1e-06, training: bool = True |
| 57 | +) -> Tensor: |
| 58 | + """ |
| 59 | + Implements DropBlock3d from `"DropBlock: A regularization method for convolutional networks" |
| 60 | + <https://arxiv.org/abs/1810.12890>`. |
| 61 | +
|
| 62 | + Args: |
| 63 | + input (Tensor[N, C, D, H, W]): The input tensor or 5-dimensions with the first one |
| 64 | + being its batch i.e. a batch with ``N`` rows. |
| 65 | + p (float): Probability of an element to be dropped. |
| 66 | + block_size (int): Size of the block to drop. |
| 67 | + inplace (bool): If set to ``True``, will do this operation in-place. Default: ``False``. |
| 68 | + eps (float): A value added to the denominator for numerical stability. Default: 1e-6. |
| 69 | + training (bool): apply dropblock if is ``True``. Default: ``True``. |
| 70 | +
|
| 71 | + Returns: |
| 72 | + Tensor[N, C, D, H, W]: The randomly zeroed tensor after dropblock. |
| 73 | + """ |
| 74 | + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): |
| 75 | + _log_api_usage_once(drop_block3d) |
| 76 | + if p < 0.0 or p > 1.0: |
| 77 | + raise ValueError(f"drop probability has to be between 0 and 1, but got {p}.") |
| 78 | + if input.ndim != 5: |
| 79 | + raise ValueError(f"input should be 5 dimensional. Got {input.ndim} dimensions.") |
| 80 | + if not training or p == 0.0: |
| 81 | + return input |
| 82 | + |
| 83 | + N, C, D, H, W = input.size() |
| 84 | + block_size = min(block_size, D, H, W) |
| 85 | + # compute the gamma of Bernoulli distribution |
| 86 | + gamma = (p * D * H * W) / ((block_size ** 3) * ((D - block_size + 1) * (H - block_size + 1) * (W - block_size + 1))) |
| 87 | + noise = torch.empty( |
| 88 | + (N, C, D - block_size + 1, H - block_size + 1, W - block_size + 1), dtype=input.dtype, device=input.device |
| 89 | + ) |
| 90 | + noise.bernoulli_(gamma) |
| 91 | + |
| 92 | + noise = F.pad(noise, [block_size // 2] * 6, value=0) |
| 93 | + noise = F.max_pool3d( |
| 94 | + noise, stride=(1, 1, 1), kernel_size=(block_size, block_size, block_size), padding=block_size // 2 |
| 95 | + ) |
| 96 | + noise = 1 - noise |
| 97 | + normalize_scale = noise.numel() / (eps + noise.sum()) |
| 98 | + if inplace: |
| 99 | + input.mul_(noise).mul_(normalize_scale) |
| 100 | + else: |
| 101 | + input = input * noise * normalize_scale |
| 102 | + return input |
| 103 | + |
| 104 | + |
| 105 | +torch.fx.wrap("drop_block2d") |
| 106 | + |
| 107 | + |
| 108 | +class DropBlock2d(nn.Module): |
| 109 | + """ |
| 110 | + See :func:`drop_block2d`. |
| 111 | + """ |
| 112 | + |
| 113 | + def __init__(self, p: float, block_size: int, inplace: bool = False, eps: float = 1e-06) -> None: |
| 114 | + super().__init__() |
| 115 | + |
| 116 | + self.p = p |
| 117 | + self.block_size = block_size |
| 118 | + self.inplace = inplace |
| 119 | + self.eps = eps |
| 120 | + |
| 121 | + def forward(self, input: Tensor) -> Tensor: |
| 122 | + """ |
| 123 | + Args: |
| 124 | + input (Tensor): Input feature map on which some areas will be randomly |
| 125 | + dropped. |
| 126 | + Returns: |
| 127 | + Tensor: The tensor after DropBlock layer. |
| 128 | + """ |
| 129 | + return drop_block2d(input, self.p, self.block_size, self.inplace, self.eps, self.training) |
| 130 | + |
| 131 | + def __repr__(self) -> str: |
| 132 | + s = f"{self.__class__.__name__}(p={self.p}, block_size={self.block_size}, inplace={self.inplace})" |
| 133 | + return s |
| 134 | + |
| 135 | + |
| 136 | +torch.fx.wrap("drop_block3d") |
| 137 | + |
| 138 | + |
| 139 | +class DropBlock3d(DropBlock2d): |
| 140 | + """ |
| 141 | + See :func:`drop_block3d`. |
| 142 | + """ |
| 143 | + |
| 144 | + def __init__(self, p: float, block_size: int, inplace: bool = False, eps: float = 1e-06) -> None: |
| 145 | + super().__init__(p, block_size, inplace, eps) |
| 146 | + |
| 147 | + def forward(self, input: Tensor) -> Tensor: |
| 148 | + """ |
| 149 | + Args: |
| 150 | + input (Tensor): Input feature map on which some areas will be randomly |
| 151 | + dropped. |
| 152 | + Returns: |
| 153 | + Tensor: The tensor after DropBlock layer. |
| 154 | + """ |
| 155 | + return drop_block3d(input, self.p, self.block_size, self.inplace, self.eps, self.training) |
0 commit comments