Skip to content

Commit 4d4285a

Browse files
committed
Allow toggling whether GeM Pooling's p parameter is trainable via requires_grad
1 parent 590c771 commit 4d4285a

File tree

2 files changed

+4
-7
lines changed

2 files changed

+4
-7
lines changed

mmcls/models/necks/gem.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,7 @@ class GeneralizedMeanPooling(nn.Module):
3535
def __init__(self, p=3., eps=1e-6, clamp=True, p_trainable=True):
3636
assert p >= 1, "'p' must be a value greater then 1"
3737
super(GeneralizedMeanPooling, self).__init__()
38-
if p_trainable:
39-
self.p = Parameter(torch.ones(1) * p)
40-
else:
41-
self.p = p
38+
self.p = Parameter(torch.ones(1) * p, requires_grad=p_trainable)
4239
self.eps = eps
4340
self.clamp = clamp
4441
self.p_trainable = p_trainable

tests/test_models/test_necks.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def test_gem_neck():
4646
neck = GeneralizedMeanPooling()
4747

4848
# default p is trainable
49-
assert isinstance(neck.p, torch.nn.Parameter)
49+
assert neck.p.requires_grad
5050

5151
# batch_size, num_features, feature_size(2)
5252
fake_input = torch.rand(1, 16, 24, 24)
@@ -68,8 +68,8 @@ def test_gem_neck():
6868
# test gem_neck with p_trainable=False
6969
neck = GeneralizedMeanPooling(p_trainable=False)
7070

71-
# p is not trainable(float)
72-
assert isinstance(neck.p, float)
71+
# p is not trainable
72+
assert not neck.p.requires_grad
7373

7474
# batch_size, num_features, feature_size(2)
7575
fake_input = torch.rand(1, 16, 24, 24)

0 commit comments

Comments
 (0)