Skip to content

Commit b000781

Browse files
Ezra-Yumzr1996
andauthored
[Enhance] Enhance ArcFaceClsHead. (#1181)
* update arcface * fix unit tests * add adv-margins add adv-margins update arcface * rebase * update doc and fix ut * rebase * update code * rebase * use label data * update set-margins * Modify Arcface related method names. Co-authored-by: mzr1996 <[email protected]>
1 parent 4fb44f8 commit b000781

File tree

10 files changed

+535
-185
lines changed

10 files changed

+535
-185
lines changed

docs/en/api/engine.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ Hooks
3131
ClassNumCheckHook
3232
PreciseBNHook
3333
VisualizationHook
34-
SwitchRecipeHook
34+
PrepareProtoBeforeValLoopHook
35+
SetAdaptiveMarginsHook
3536

3637
.. module:: mmcls.engine.optimizers
3738

docs/en/api/models.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ Heads
140140
EfficientFormerClsHead
141141
DeiTClsHead
142142
ConformerHead
143+
ArcFaceClsHead
143144
MultiLabelClsHead
144145
MultiLabelLinearClsHead
145146
CSRAClsHead

mmcls/engine/hooks/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
from .class_num_check_hook import ClassNumCheckHook
3+
from .margin_head_hooks import SetAdaptiveMarginsHook
34
from .precise_bn_hook import PreciseBNHook
45
from .retriever_hooks import PrepareProtoBeforeValLoopHook
56
from .switch_recipe_hook import SwitchRecipeHook
67
from .visualization_hook import VisualizationHook
78

89
__all__ = [
910
'ClassNumCheckHook', 'PreciseBNHook', 'VisualizationHook',
10-
'SwitchRecipeHook', 'PrepareProtoBeforeValLoopHook'
11+
'SwitchRecipeHook', 'PrepareProtoBeforeValLoopHook',
12+
'SetAdaptiveMarginsHook'
1113
]
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Copyright (c) OpenMMLab. All rights reserved
2+
import numpy as np
3+
from mmengine.hooks import Hook
4+
from mmengine.model import is_model_wrapper
5+
6+
from mmcls.models.heads import ArcFaceClsHead
7+
from mmcls.registry import HOOKS
8+
9+
10+
@HOOKS.register_module()
11+
class SetAdaptiveMarginsHook(Hook):
12+
r"""Set adaptive-margins in ArcFaceClsHead based on the power of
13+
category-wise count.
14+
15+
A PyTorch implementation of paper `Google Landmark Recognition 2020
16+
Competition Third Place Solution <https://arxiv.org/abs/2010.05350>`_.
17+
The margins will be
18+
:math:`\text{f}(n) = (marginMax - marginMin) · norm(n^p) + marginMin`.
19+
The `n` indicates the number of occurrences of a category.
20+
21+
Args:
22+
margin_min (float): Lower bound of margins. Defaults to 0.05.
23+
margin_max (float): Upper bound of margins. Defaults to 0.5.
24+
power (float): The power of category freqercy. Defaults to -0.25.
25+
"""
26+
27+
def __init__(self, margin_min=0.05, margin_max=0.5, power=-0.25) -> None:
28+
self.margin_min = margin_min
29+
self.margin_max = margin_max
30+
self.margin_range = margin_max - margin_min
31+
self.p = power
32+
33+
def before_train(self, runner):
34+
"""change the margins in ArcFaceClsHead.
35+
36+
Args:
37+
runner (obj: `Runner`): Runner.
38+
"""
39+
model = runner.model
40+
if is_model_wrapper(model):
41+
model = model.module
42+
43+
if (hasattr(model, 'head')
44+
and not isinstance(model.head, ArcFaceClsHead)):
45+
raise ValueError(
46+
'Hook ``SetFreqPowAdvMarginsHook`` could only be used '
47+
f'for ``ArcFaceClsHead``, but get {type(model.head)}')
48+
49+
# generate margins base on the dataset.
50+
gt_labels = runner.train_dataloader.dataset.get_gt_labels()
51+
label_count = np.bincount(gt_labels)
52+
label_count[label_count == 0] = 1 # At least one occurrence
53+
pow_freq = np.power(label_count, self.p)
54+
55+
min_f, max_f = pow_freq.min(), pow_freq.max()
56+
normized_pow_freq = (pow_freq - min_f) / (max_f - min_f)
57+
margins = normized_pow_freq * self.margin_range + self.margin_min
58+
59+
assert len(margins) == runner.model.head.num_classes
60+
61+
model.head.set_margins(margins)

mmcls/models/backbones/hornet.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -250,20 +250,24 @@ def forward(self, x):
250250

251251
@MODELS.register_module()
252252
class HorNet(BaseBackbone):
253-
"""HorNet
254-
A PyTorch impl of : `HorNet: Efficient High-Order Spatial Interactions
255-
with Recursive Gated Convolutions`
256-
Inspiration from
257-
https://github.com/raoyongming/HorNet
253+
"""HorNet.
254+
255+
A PyTorch implementation of paper `HorNet: Efficient High-Order Spatial
256+
Interactions with Recursive Gated Convolutions
257+
<https://arxiv.org/abs/2207.14284>`_ .
258+
Inspiration from https://github.com/raoyongming/HorNet
259+
258260
Args:
259261
arch (str | dict): HorNet architecture.
262+
260263
If use string, choose from 'tiny', 'small', 'base' and 'large'.
261264
If use dict, it should have below keys:
262265
- **base_dim** (int): The base dimensions of embedding.
263266
- **depths** (List[int]): The number of blocks in each stage.
264267
- **orders** (List[int]): The number of order of gnConv in each
265268
stage.
266269
- **dw_cfg** (List[dict]): The Config for dw conv.
270+
267271
Defaults to 'tiny'.
268272
in_channels (int): Number of input image channels. Defaults to 3.
269273
drop_path_rate (float): Stochastic depth rate. Defaults to 0.

mmcls/models/heads/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2-
from .arcface_head import ArcFaceClsHead
32
from .cls_head import ClsHead
43
from .conformer_head import ConformerHead
54
from .deit_head import DeiTClsHead
65
from .efficientformer_head import EfficientFormerClsHead
76
from .linear_head import LinearClsHead
7+
from .margin_head import ArcFaceClsHead
88
from .multi_label_cls_head import MultiLabelClsHead
99
from .multi_label_csra_head import CSRAClsHead
1010
from .multi_label_linear_head import MultiLabelLinearClsHead

mmcls/models/heads/arcface_head.py

Lines changed: 0 additions & 176 deletions
This file was deleted.

0 commit comments

Comments
 (0)