Skip to content

Commit 5568744

Browse files
xiaohu2015 and datumbox
authored
New Feature: add DropBlock layer (#5416)
* Create dropblock.py * add dropblock2d * fix pylint * refactor dropblock * add dropblock * Rename dropblock.py to drop_block.py * fix pylint * add dropblock * add dropblock3d * add drop_block3d * add dropblock * Update drop_block.py * Update torchvision/ops/drop_block.py Co-authored-by: Vasilis Vryniotis <[email protected]> * Update torchvision/ops/drop_block.py Co-authored-by: Vasilis Vryniotis <[email protected]> * Update torchvision/ops/drop_block.py Co-authored-by: Vasilis Vryniotis <[email protected]> * Update torchvision/ops/drop_block.py Co-authored-by: Vasilis Vryniotis <[email protected]> * Update drop_block.py * Update drop_block.py * import torch.fx * fix lint * fix lint * Update drop_block.py * improve dropblock * add dropblock * refactor dropblock * fix doc * remove the limitation of block_size * Update torchvision/ops/drop_block.py Co-authored-by: Vasilis Vryniotis <[email protected]> * fix lint * fix lint * add dropblock * Fix linter * add dropblock random check * reduce test time * Update test_ops.py * speed the dropblock test * fix lint Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent 1fc53b2 commit 5568744

File tree

4 files changed

+257
-0
lines changed

4 files changed

+257
-0
lines changed

docs/source/ops.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ Operators
2020
box_iou
2121
clip_boxes_to_image
2222
deform_conv2d
23+
drop_block2d
24+
drop_block3d
2325
generalized_box_iou
2426
generalized_box_iou_loss
2527
masks_to_boxes
@@ -48,3 +50,5 @@ Operators
4850
Conv2dNormActivation
4951
Conv3dNormActivation
5052
SqueezeExcitation
53+
DropBlock2d
54+
DropBlock3d

test/test_ops.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import os
33
from abc import ABC, abstractmethod
44
from functools import lru_cache
5+
from itertools import product
56
from typing import Callable, List, Tuple
67

78
import numpy as np
@@ -57,6 +58,16 @@ def forward(self, a):
5758
self.layer(a)
5859

5960

61+
class DropBlockWrapper(nn.Module):
    """Minimal ``nn.Module`` shell that holds a DropBlock layer as a sub-module.

    Attributes are documented inline; ``forward`` calls the wrapped layer and
    does not return its output.
    """

    def __init__(self, obj):
        super().__init__()
        # Number of inputs ``forward`` accepts; read by ``test_is_leaf_node``.
        self.n_inputs = 1
        # The wrapped layer, registered as a child module.
        self.layer = obj

    def forward(self, a):
        # Only the call itself matters here; the layer's output is not returned.
        self.layer(a)
69+
70+
6071
class RoIOpTester(ABC):
6172
dtype = torch.float64
6273

@@ -1357,5 +1368,87 @@ def test_split_normalization_params(self, norm_layer):
13571368
assert len(params[1]) == 82
13581369

13591370

1371+
class TestDropBlock:
    """Tests for ``ops.DropBlock2d`` and ``ops.DropBlock3d``."""

    @pytest.mark.parametrize("seed", range(10))
    @pytest.mark.parametrize("dim", [2, 3])
    @pytest.mark.parametrize("p", [0, 0.5])
    @pytest.mark.parametrize("block_size", [5, 11])
    @pytest.mark.parametrize("inplace", [True, False])
    def test_drop_block(self, seed, dim, p, block_size, inplace):
        """Check the deterministic properties of a single forward pass.

        With ``p == 0`` the output must equal the input. When the block is as
        large as the feature map, each (batch, channel) slice must be either
        fully zero or fully non-zero.
        """
        torch.manual_seed(seed)
        batch_size, channels = 5, 3
        # Height, width and (for 3d) depth all share the same extent.
        side = 11
        layer_cls = {2: ops.DropBlock2d, 3: ops.DropBlock3d}[dim]

        x = torch.ones(size=(batch_size, channels) + (side,) * dim)
        layer = layer_cls(p=p, block_size=block_size, inplace=inplace)
        feature_size = side ** dim
        # Smoke-check that building the string representation does not raise.
        repr(layer)

        out = layer(x)
        if p == 0:
            assert out.equal(x)
        if block_size == side:
            for sample_idx, channel_idx in product(range(batch_size), range(channels)):
                assert out[sample_idx, channel_idx].count_nonzero() in (0, feature_size)

    @pytest.mark.parametrize("seed", range(10))
    @pytest.mark.parametrize("dim", [2, 3])
    @pytest.mark.parametrize("p", [0.1, 0.2])
    @pytest.mark.parametrize("block_size", [3])
    @pytest.mark.parametrize("inplace", [False])
    def test_drop_block_random(self, seed, dim, p, block_size, inplace):
        """Check that the empirical drop rate over many trials is close to ``p``.

        The fraction of zeroed elements accumulated over 250 forward passes
        must be within 15% (relative error) of ``p``.
        """
        torch.manual_seed(seed)
        batch_size, channels = 5, 3
        side = 11
        layer_cls = {2: ops.DropBlock2d, 3: ops.DropBlock3d}[dim]

        x = torch.ones(size=(batch_size, channels) + (side,) * dim)
        layer = layer_cls(p=p, block_size=block_size, inplace=inplace)

        trials = 250
        cell_numel = torch.tensor(x.shape).prod()
        # Every trial sees the same number of elements, so the total is a product.
        num_samples = cell_numel * trials
        counts = 0
        with torch.no_grad():
            for _ in range(trials):
                out = layer(x)
                kept = out.nonzero().size(0)
                counts += cell_numel - kept

        assert abs(p - counts / num_samples) / p < 0.15

    def make_obj(self, dim, p, block_size, inplace, wrap=False):
        """Build a DropBlock layer for ``dim`` (2 or 3), optionally wrapped in ``DropBlockWrapper``."""
        if dim == 2:
            layer_cls = ops.DropBlock2d
        elif dim == 3:
            layer_cls = ops.DropBlock3d
        obj = layer_cls(p, block_size, inplace)
        if wrap:
            return DropBlockWrapper(obj)
        return obj

    @pytest.mark.parametrize("dim", (2, 3))
    @pytest.mark.parametrize("p", [0, 1])
    @pytest.mark.parametrize("block_size", [5, 7])
    @pytest.mark.parametrize("inplace", [True, False])
    def test_is_leaf_node(self, dim, p, block_size, inplace):
        """Check the graph node names reported for a wrapped DropBlock layer.

        Both returned name lists must have the same length: one node for the
        layer plus one per wrapper input.
        """
        wrapped = self.make_obj(dim, p, block_size, inplace, wrap=True)
        node_names = get_graph_node_names(wrapped)

        assert len(node_names) == 2
        first_names, second_names = node_names[0], node_names[1]
        assert len(first_names) == len(second_names)
        assert len(first_names) == 1 + wrapped.n_inputs
1451+
1452+
13601453
if __name__ == "__main__":
13611454
pytest.main([__file__])

torchvision/ops/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
)
1212
from .boxes import box_convert
1313
from .deform_conv import deform_conv2d, DeformConv2d
14+
from .drop_block import drop_block2d, DropBlock2d, drop_block3d, DropBlock3d
1415
from .feature_pyramid_network import FeaturePyramidNetwork
1516
from .focal_loss import sigmoid_focal_loss
1617
from .giou_loss import generalized_box_iou_loss
@@ -55,4 +56,8 @@
5556
"Conv3dNormActivation",
5657
"SqueezeExcitation",
5758
"generalized_box_iou_loss",
59+
"drop_block2d",
60+
"DropBlock2d",
61+
"drop_block3d",
62+
"DropBlock3d",
5863
]

torchvision/ops/drop_block.py

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
import torch
2+
import torch.fx
3+
import torch.nn.functional as F
4+
from torch import nn, Tensor
5+
6+
from ..utils import _log_api_usage_once
7+
8+
9+
def drop_block2d(
    input: Tensor, p: float, block_size: int, inplace: bool = False, eps: float = 1e-06, training: bool = True
) -> Tensor:
    """
    Implements DropBlock2d from `"DropBlock: A regularization method for convolutional networks"
    <https://arxiv.org/abs/1810.12890>`.

    Args:
        input (Tensor[N, C, H, W]): The input tensor of 4 dimensions with the first one
            being its batch i.e. a batch with ``N`` rows.
        p (float): Probability of an element to be dropped.
        block_size (int): Size of the block to drop. Both odd and even sizes are supported;
            values larger than ``H`` or ``W`` are clamped to the smaller spatial extent.
        inplace (bool): If set to ``True``, will do this operation in-place. Default: ``False``.
        eps (float): A value added to the denominator for numerical stability. Default: 1e-6.
        training (bool): apply dropblock if is ``True``. Default: ``True``.

    Returns:
        Tensor[N, C, H, W]: The randomly zeroed tensor after dropblock.

    Raises:
        ValueError: If ``p`` is outside ``[0, 1]``, if ``input`` is not 4-dimensional, or if
            ``block_size`` is smaller than 1 while dropblock is actually applied.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(drop_block2d)
    if p < 0.0 or p > 1.0:
        raise ValueError(f"drop probability has to be between 0 and 1, but got {p}.")
    if input.ndim != 4:
        raise ValueError(f"input should be 4 dimensional. Got {input.ndim} dimensions.")
    if not training or p == 0.0:
        return input
    # Checked only after the early return so that eval-mode / p == 0 calls keep
    # accepting any block_size, as before.
    if block_size < 1:
        raise ValueError(f"block_size has to be a positive integer, but got {block_size}.")

    N, C, H, W = input.size()
    block_size = min(block_size, W, H)
    # compute the gamma of Bernoulli distribution
    gamma = (p * H * W) / ((block_size ** 2) * ((H - block_size + 1) * (W - block_size + 1)))
    # One Bernoulli sample per position where a full block fits inside the feature map.
    noise = torch.empty((N, C, H - block_size + 1, W - block_size + 1), dtype=input.dtype, device=input.device)
    noise.bernoulli_(gamma)

    # Treat every sampled position as the top-left corner of a block. Padding by
    # ``block_size - 1`` on each side and pooling with no extra padding gives an
    # (H, W) mask for both odd and even block sizes. Padding and pooling by
    # ``block_size // 2`` instead produced an (H + 2, W + 2) mask for even sizes.
    noise = F.pad(noise, [block_size - 1] * 4, value=0)
    noise = F.max_pool2d(noise, stride=(1, 1), kernel_size=(block_size, block_size))
    # Flip the mask: 1 marks kept elements, 0 marks dropped ones.
    noise = 1 - noise
    # Rescale by the inverse of the kept fraction, computed over the whole mask.
    normalize_scale = noise.numel() / (eps + noise.sum())
    if inplace:
        input.mul_(noise).mul_(normalize_scale)
    else:
        input = input * noise * normalize_scale
    return input
53+
54+
55+
def drop_block3d(
    input: Tensor, p: float, block_size: int, inplace: bool = False, eps: float = 1e-06, training: bool = True
) -> Tensor:
    """
    Implements DropBlock3d from `"DropBlock: A regularization method for convolutional networks"
    <https://arxiv.org/abs/1810.12890>`.

    Args:
        input (Tensor[N, C, D, H, W]): The input tensor of 5 dimensions with the first one
            being its batch i.e. a batch with ``N`` rows.
        p (float): Probability of an element to be dropped.
        block_size (int): Size of the block to drop. Both odd and even sizes are supported;
            values larger than ``D``, ``H`` or ``W`` are clamped to the smallest spatial extent.
        inplace (bool): If set to ``True``, will do this operation in-place. Default: ``False``.
        eps (float): A value added to the denominator for numerical stability. Default: 1e-6.
        training (bool): apply dropblock if is ``True``. Default: ``True``.

    Returns:
        Tensor[N, C, D, H, W]: The randomly zeroed tensor after dropblock.

    Raises:
        ValueError: If ``p`` is outside ``[0, 1]``, if ``input`` is not 5-dimensional, or if
            ``block_size`` is smaller than 1 while dropblock is actually applied.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(drop_block3d)
    if p < 0.0 or p > 1.0:
        raise ValueError(f"drop probability has to be between 0 and 1, but got {p}.")
    if input.ndim != 5:
        raise ValueError(f"input should be 5 dimensional. Got {input.ndim} dimensions.")
    if not training or p == 0.0:
        return input
    # Checked only after the early return so that eval-mode / p == 0 calls keep
    # accepting any block_size, as before.
    if block_size < 1:
        raise ValueError(f"block_size has to be a positive integer, but got {block_size}.")

    N, C, D, H, W = input.size()
    block_size = min(block_size, D, H, W)
    # compute the gamma of Bernoulli distribution
    gamma = (p * D * H * W) / ((block_size ** 3) * ((D - block_size + 1) * (H - block_size + 1) * (W - block_size + 1)))
    # One Bernoulli sample per position where a full block fits inside the volume.
    noise = torch.empty(
        (N, C, D - block_size + 1, H - block_size + 1, W - block_size + 1), dtype=input.dtype, device=input.device
    )
    noise.bernoulli_(gamma)

    # Treat every sampled position as the front-top-left corner of a block. Padding
    # by ``block_size - 1`` on each side and pooling with no extra padding gives a
    # (D, H, W) mask for both odd and even block sizes. Padding and pooling by
    # ``block_size // 2`` instead produced a (D + 2, H + 2, W + 2) mask for even sizes.
    noise = F.pad(noise, [block_size - 1] * 6, value=0)
    noise = F.max_pool3d(noise, stride=(1, 1, 1), kernel_size=(block_size, block_size, block_size))
    # Flip the mask: 1 marks kept elements, 0 marks dropped ones.
    noise = 1 - noise
    # Rescale by the inverse of the kept fraction, computed over the whole mask.
    normalize_scale = noise.numel() / (eps + noise.sum())
    if inplace:
        input.mul_(noise).mul_(normalize_scale)
    else:
        input = input * noise * normalize_scale
    return input
103+
104+
105+
torch.fx.wrap("drop_block2d")
106+
107+
108+
class DropBlock2d(nn.Module):
    """
    Module form of :func:`drop_block2d`; see that function for the meaning of
    ``p``, ``block_size``, ``inplace`` and ``eps``.
    """

    def __init__(self, p: float, block_size: int, inplace: bool = False, eps: float = 1e-06) -> None:
        super().__init__()

        # Stored unchanged and forwarded to drop_block2d on every call.
        self.eps = eps
        self.inplace = inplace
        self.block_size = block_size
        self.p = p

    def forward(self, input: Tensor) -> Tensor:
        """
        Args:
            input (Tensor): Input feature map on which some areas will be randomly
                dropped.
        Returns:
            Tensor: The tensor after DropBlock layer.
        """
        # ``self.training`` is passed through as the functional op's ``training`` flag.
        return drop_block2d(
            input,
            p=self.p,
            block_size=self.block_size,
            inplace=self.inplace,
            eps=self.eps,
            training=self.training,
        )

    def __repr__(self) -> str:
        # ``eps`` is not included in the representation.
        cls_name = self.__class__.__name__
        return f"{cls_name}(p={self.p}, block_size={self.block_size}, inplace={self.inplace})"
134+
135+
136+
torch.fx.wrap("drop_block3d")
137+
138+
139+
class DropBlock3d(DropBlock2d):
    """
    Module form of :func:`drop_block3d`; see that function for the meaning of
    the constructor arguments.
    """

    def __init__(self, p: float, block_size: int, inplace: bool = False, eps: float = 1e-06) -> None:
        # The parent constructor stores all four hyper-parameters.
        super().__init__(p=p, block_size=block_size, inplace=inplace, eps=eps)

    def forward(self, input: Tensor) -> Tensor:
        """
        Args:
            input (Tensor): Input feature map on which some areas will be randomly
                dropped.
        Returns:
            Tensor: The tensor after DropBlock layer.
        """
        # ``self.training`` is passed through as the functional op's ``training`` flag.
        return drop_block3d(
            input,
            p=self.p,
            block_size=self.block_size,
            inplace=self.inplace,
            eps=self.eps,
            training=self.training,
        )

0 commit comments

Comments
 (0)