Skip to content

Commit 2495400

Browse files
committed
Merge branch 'dev'
2 parents 8387358 + c737e65 commit 2495400

File tree

17 files changed

+157
-81
lines changed

17 files changed

+157
-81
lines changed

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@ repos:
2929
rev: 0.7.9
3030
hooks:
3131
- id: mdformat
32-
args: ["--number", "--table-width", "200"]
32+
args: ["--number", "--table-width", "200", '--disable-escape', 'backslash', '--disable-escape', 'link-enclosure']
3333
additional_dependencies:
34-
- mdformat-openmmlab
34+
- "mdformat-openmmlab>=0.0.4"
3535
- mdformat_frontmatter
3636
- linkify-it-py
3737
- repo: https://github.com/codespell-project/codespell

README.md

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,12 @@ The MMClassification 1.0 has released! It's still unstable and in release candid
6464
to [the 1.x branch](https://github.com/open-mmlab/mmclassification/tree/1.x) and discuss it with us in
6565
[the discussion](https://github.com/open-mmlab/mmclassification/discussions).
6666

67+
v0.25.0 was released in 06/12/2022.
68+
Highlights of the new version:
69+
70+
- Support MLU backend.
71+
- Add `dist_train_arm.sh` for ARM device.
72+
6773
v0.24.1 was released in 31/10/2022.
6874
Highlights of the new version:
6975

@@ -75,13 +81,6 @@ Highlights of the new version:
7581
- Support **HorNet**, **EfficientFormerm**, **SwinTransformer V2** and **MViT** backbones.
7682
- Support Standford Cars dataset.
7783

78-
v0.23.0 was released in 1/5/2022.
79-
Highlights of the new version:
80-
81-
- Support **DenseNet**, **VAN** and **PoolFormer**, and provide pre-trained models.
82-
- Support training on IPU.
83-
- New style API docs, welcome [view it](https://mmclassification.readthedocs.io/en/master/api/models.html).
84-
8584
Please refer to [changelog.md](docs/en/changelog.md) for more details and other release history.
8685

8786
## Installation

README_zh-CN.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,11 @@ MMClassification 是一款基于 PyTorch 的开源图像分类工具箱,是 [O
6363

6464
MMClassification 1.0 已经发布!目前仍在公测中,如果希望试用,请切换到 [1.x 分支](https://github.com/open-mmlab/mmclassification/tree/1.x),并在[讨论版](https://github.com/open-mmlab/mmclassification/discussions) 参加开发讨论!
6565

66+
2022/12/06 发布了 v0.25.0 版本
67+
68+
- 支持 MLU 设备
69+
- 添加了用于 ARM 设备训练的 `dist_train_arm.sh`
70+
6671
2022/10/31 发布了 v0.24.1 版本
6772

6873
- 支持了华为昇腾 NPU 设备。

configs/t2t_vit/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
## Abstract
88

9-
Transformers, which are popular for language modeling, have been explored for solving vision tasks recently, \\eg, the Vision Transformer (ViT) for image classification. The ViT model splits each image into a sequence of tokens with fixed length and then applies multiple Transformer layers to model their global relation for classification. However, ViT achieves inferior performance to CNNs when trained from scratch on a midsize dataset like ImageNet. We find it is because: 1) the simple tokenization of input images fails to model the important local structure such as edges and lines among neighboring pixels, leading to low training sample efficiency; 2) the redundant attention backbone design of ViT leads to limited feature richness for fixed computation budgets and limited training samples. To overcome such limitations, we propose a new Tokens-To-Token Vision Transformer (T2T-ViT), which incorporates 1) a layer-wise Tokens-to-Token (T2T) transformation to progressively structurize the image to tokens by recursively aggregating neighboring Tokens into one Token (Tokens-to-Token), such that local structure represented by surrounding tokens can be modeled and tokens length can be reduced; 2) an efficient backbone with a deep-narrow structure for vision transformer motivated by CNN architecture design after empirical study. Notably, T2T-ViT reduces the parameter count and MACs of vanilla ViT by half, while achieving more than 3.0% improvement when trained from scratch on ImageNet. It also outperforms ResNets and achieves comparable performance with MobileNets by directly training on ImageNet. For example, T2T-ViT with comparable size to ResNet50 (21.5M parameters) can achieve 83.3% top1 accuracy in image resolution 384×384 on ImageNet.
9+
Transformers, which are popular for language modeling, have been explored for solving vision tasks recently, e.g., the Vision Transformer (ViT) for image classification. The ViT model splits each image into a sequence of tokens with fixed length and then applies multiple Transformer layers to model their global relation for classification. However, ViT achieves inferior performance to CNNs when trained from scratch on a midsize dataset like ImageNet. We find it is because: 1) the simple tokenization of input images fails to model the important local structure such as edges and lines among neighboring pixels, leading to low training sample efficiency; 2) the redundant attention backbone design of ViT leads to limited feature richness for fixed computation budgets and limited training samples. To overcome such limitations, we propose a new Tokens-To-Token Vision Transformer (T2T-ViT), which incorporates 1) a layer-wise Tokens-to-Token (T2T) transformation to progressively structurize the image to tokens by recursively aggregating neighboring Tokens into one Token (Tokens-to-Token), such that local structure represented by surrounding tokens can be modeled and tokens length can be reduced; 2) an efficient backbone with a deep-narrow structure for vision transformer motivated by CNN architecture design after empirical study. Notably, T2T-ViT reduces the parameter count and MACs of vanilla ViT by half, while achieving more than 3.0% improvement when trained from scratch on ImageNet. It also outperforms ResNets and achieves comparable performance with MobileNets by directly training on ImageNet. For example, T2T-ViT with comparable size to ResNet50 (21.5M parameters) can achieve 83.3% top1 accuracy in image resolution 384×384 on ImageNet.
1010

1111
<div align=center>
1212
<img src="https://user-images.githubusercontent.com/26739999/142578381-e9040610-05d9-457c-8bf5-01c2fa94add2.png" width="60%"/>

docker/serve/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ ARG CUDNN="7"
44
FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel
55

66
ARG MMCV="1.7.0"
7-
ARG MMCLS="0.24.1"
7+
ARG MMCLS="0.25.0"
88

99
ENV PYTHONUNBUFFERED TRUE
1010

docs/en/changelog.md

Lines changed: 66 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,33 @@
11
# Changelog
22

3+
## v0.25.0(06/12/2022)
4+
5+
### Highlights
6+
7+
- Support MLU backend.
8+
9+
### New Features
10+
11+
- Support MLU backend. ([#1159](https://github.com/open-mmlab/mmclassification/pull/1159))
12+
- Support Activation Checkpointing for ConvNeXt. ([#1152](https://github.com/open-mmlab/mmclassification/pull/1152))
13+
14+
### Improvements
15+
16+
- Add `dist_train_arm.sh` for ARM device and update NPU results. ([#1218](https://github.com/open-mmlab/mmclassification/pull/1218))
17+
18+
### Bug Fixes
19+
20+
- Fix a bug caused `MMClsWandbHook` stuck. ([#1242](https://github.com/open-mmlab/mmclassification/pull/1242))
21+
- Fix the redundant `device_ids` in `tools/test.py`. ([#1215](https://github.com/open-mmlab/mmclassification/pull/1215))
22+
23+
### Docs Update
24+
25+
- Add version banner and version warning in master docs. ([#1216](https://github.com/open-mmlab/mmclassification/pull/1216))
26+
- Update NPU support doc. ([#1198](https://github.com/open-mmlab/mmclassification/pull/1198))
27+
- Fixed typo in `pytorch2torchscript.md`. ([#1173](https://github.com/open-mmlab/mmclassification/pull/1173))
28+
- Fix typo in `miscellaneous.md`. ([#1137](https://github.com/open-mmlab/mmclassification/pull/1137))
29+
- further detail for the doc for `ClassBalancedDataset`. ([#901](https://github.com/open-mmlab/mmclassification/pull/901))
30+
331
## v0.24.1(31/10/2022)
432

533
### New Features
@@ -28,14 +56,14 @@
2856

2957
### Improvements
3058

31-
- \[Improve\] replace loop of progressbar in api/test. ([#878](https://github.com/open-mmlab/mmclassification/pull/878))
32-
- \[Enhance\] RepVGG for YOLOX-PAI. ([#1025](https://github.com/open-mmlab/mmclassification/pull/1025))
33-
- \[Enhancement\] Update VAN. ([#1017](https://github.com/open-mmlab/mmclassification/pull/1017))
34-
- \[Refactor\] Re-write `get_sinusoid_encoding` from third-party implementation. ([#965](https://github.com/open-mmlab/mmclassification/pull/965))
35-
- \[Improve\] Upgrade onnxsim to v0.4.0. ([#915](https://github.com/open-mmlab/mmclassification/pull/915))
36-
- \[Improve\] Fixed typo in `RepVGG`. ([#985](https://github.com/open-mmlab/mmclassification/pull/985))
37-
- \[Improve\] Using `train_step` instead of `forward` in PreciseBNHook ([#964](https://github.com/open-mmlab/mmclassification/pull/964))
38-
- \[Improve\] Use `forward_dummy` to calculate FLOPS. ([#953](https://github.com/open-mmlab/mmclassification/pull/953))
59+
- [Improve] replace loop of progressbar in api/test. ([#878](https://github.com/open-mmlab/mmclassification/pull/878))
60+
- [Enhance] RepVGG for YOLOX-PAI. ([#1025](https://github.com/open-mmlab/mmclassification/pull/1025))
61+
- [Enhancement] Update VAN. ([#1017](https://github.com/open-mmlab/mmclassification/pull/1017))
62+
- [Refactor] Re-write `get_sinusoid_encoding` from third-party implementation. ([#965](https://github.com/open-mmlab/mmclassification/pull/965))
63+
- [Improve] Upgrade onnxsim to v0.4.0. ([#915](https://github.com/open-mmlab/mmclassification/pull/915))
64+
- [Improve] Fixed typo in `RepVGG`. ([#985](https://github.com/open-mmlab/mmclassification/pull/985))
65+
- [Improve] Using `train_step` instead of `forward` in PreciseBNHook ([#964](https://github.com/open-mmlab/mmclassification/pull/964))
66+
- [Improve] Use `forward_dummy` to calculate FLOPS. ([#953](https://github.com/open-mmlab/mmclassification/pull/953))
3967

4068
### Bug Fixes
4169

@@ -102,13 +130,13 @@
102130

103131
### New Features
104132

105-
- \[Feature\] Support resize relative position embedding in `SwinTransformer`. ([#749](https://github.com/open-mmlab/mmclassification/pull/749))
106-
- \[Feature\] Add PoolFormer backbone and checkpoints. ([#746](https://github.com/open-mmlab/mmclassification/pull/746))
133+
- [Feature] Support resize relative position embedding in `SwinTransformer`. ([#749](https://github.com/open-mmlab/mmclassification/pull/749))
134+
- [Feature] Add PoolFormer backbone and checkpoints. ([#746](https://github.com/open-mmlab/mmclassification/pull/746))
107135

108136
### Improvements
109137

110-
- \[Enhance\] Improve CPE performance by reduce memory copy. ([#762](https://github.com/open-mmlab/mmclassification/pull/762))
111-
- \[Enhance\] Add extra dataloader settings in configs. ([#752](https://github.com/open-mmlab/mmclassification/pull/752))
138+
- [Enhance] Improve CPE performance by reduce memory copy. ([#762](https://github.com/open-mmlab/mmclassification/pull/762))
139+
- [Enhance] Add extra dataloader settings in configs. ([#752](https://github.com/open-mmlab/mmclassification/pull/752))
112140

113141
## v0.22.0(30/3/2022)
114142

@@ -120,29 +148,29 @@
120148

121149
### New Features
122150

123-
- \[Feature\] Add CSPNet and backbone and checkpoints ([#735](https://github.com/open-mmlab/mmclassification/pull/735))
124-
- \[Feature\] Add `CustomDataset`. ([#738](https://github.com/open-mmlab/mmclassification/pull/738))
125-
- \[Feature\] Add diff seeds to diff ranks. ([#744](https://github.com/open-mmlab/mmclassification/pull/744))
126-
- \[Feature\] Support ConvMixer. ([#716](https://github.com/open-mmlab/mmclassification/pull/716))
127-
- \[Feature\] Our `dist_train` & `dist_test` tools support distributed training on multiple machines. ([#734](https://github.com/open-mmlab/mmclassification/pull/734))
128-
- \[Feature\] Add RepMLP backbone and checkpoints. ([#709](https://github.com/open-mmlab/mmclassification/pull/709))
129-
- \[Feature\] Support CUB dataset. ([#703](https://github.com/open-mmlab/mmclassification/pull/703))
130-
- \[Feature\] Support ResizeMix. ([#676](https://github.com/open-mmlab/mmclassification/pull/676))
151+
- [Feature] Add CSPNet and backbone and checkpoints ([#735](https://github.com/open-mmlab/mmclassification/pull/735))
152+
- [Feature] Add `CustomDataset`. ([#738](https://github.com/open-mmlab/mmclassification/pull/738))
153+
- [Feature] Add diff seeds to diff ranks. ([#744](https://github.com/open-mmlab/mmclassification/pull/744))
154+
- [Feature] Support ConvMixer. ([#716](https://github.com/open-mmlab/mmclassification/pull/716))
155+
- [Feature] Our `dist_train` & `dist_test` tools support distributed training on multiple machines. ([#734](https://github.com/open-mmlab/mmclassification/pull/734))
156+
- [Feature] Add RepMLP backbone and checkpoints. ([#709](https://github.com/open-mmlab/mmclassification/pull/709))
157+
- [Feature] Support CUB dataset. ([#703](https://github.com/open-mmlab/mmclassification/pull/703))
158+
- [Feature] Support ResizeMix. ([#676](https://github.com/open-mmlab/mmclassification/pull/676))
131159

132160
### Improvements
133161

134-
- \[Enhance\] Use `--a-b` instead of `--a_b` in arguments. ([#754](https://github.com/open-mmlab/mmclassification/pull/754))
135-
- \[Enhance\] Add `get_cat_ids` and `get_gt_labels` to KFoldDataset. ([#721](https://github.com/open-mmlab/mmclassification/pull/721))
136-
- \[Enhance\] Set torch seed in `worker_init_fn`. ([#733](https://github.com/open-mmlab/mmclassification/pull/733))
162+
- [Enhance] Use `--a-b` instead of `--a_b` in arguments. ([#754](https://github.com/open-mmlab/mmclassification/pull/754))
163+
- [Enhance] Add `get_cat_ids` and `get_gt_labels` to KFoldDataset. ([#721](https://github.com/open-mmlab/mmclassification/pull/721))
164+
- [Enhance] Set torch seed in `worker_init_fn`. ([#733](https://github.com/open-mmlab/mmclassification/pull/733))
137165

138166
### Bug Fixes
139167

140-
- \[Fix\] Fix the discontiguous output feature map of ConvNeXt. ([#743](https://github.com/open-mmlab/mmclassification/pull/743))
168+
- [Fix] Fix the discontiguous output feature map of ConvNeXt. ([#743](https://github.com/open-mmlab/mmclassification/pull/743))
141169

142170
### Docs Update
143171

144-
- \[Docs\] Add brief installation steps in README for copy&paste. ([#755](https://github.com/open-mmlab/mmclassification/pull/755))
145-
- \[Docs\] fix logo url link from mmocr to mmcls. ([#732](https://github.com/open-mmlab/mmclassification/pull/732))
172+
- [Docs] Add brief installation steps in README for copy&paste. ([#755](https://github.com/open-mmlab/mmclassification/pull/755))
173+
- [Docs] fix logo url link from mmocr to mmcls. ([#732](https://github.com/open-mmlab/mmclassification/pull/732))
146174

147175
## v0.21.0(04/03/2022)
148176

@@ -245,18 +273,18 @@
245273

246274
### Improvements
247275

248-
- \[Reproduction\] Reproduce RegNetX training accuracy. ([#587](https://github.com/open-mmlab/mmclassification/pull/587))
249-
- \[Reproduction\] Reproduce training results of T2T-ViT. ([#610](https://github.com/open-mmlab/mmclassification/pull/610))
250-
- \[Enhance\] Provide high-acc training settings of ResNet. ([#572](https://github.com/open-mmlab/mmclassification/pull/572))
251-
- \[Enhance\] Set a random seed when the user does not set a seed. ([#554](https://github.com/open-mmlab/mmclassification/pull/554))
252-
- \[Enhance\] Added `NumClassCheckHook` and unit tests. ([#559](https://github.com/open-mmlab/mmclassification/pull/559))
253-
- \[Enhance\] Enhance feature extraction function. ([#593](https://github.com/open-mmlab/mmclassification/pull/593))
254-
- \[Enhance\] Improve efficiency of precision, recall, f1_score and support. ([#595](https://github.com/open-mmlab/mmclassification/pull/595))
255-
- \[Enhance\] Improve accuracy calculation performance. ([#592](https://github.com/open-mmlab/mmclassification/pull/592))
256-
- \[Refactor\] Refactor `analysis_log.py`. ([#529](https://github.com/open-mmlab/mmclassification/pull/529))
257-
- \[Refactor\] Use new API of matplotlib to handle blocking input in visualization. ([#568](https://github.com/open-mmlab/mmclassification/pull/568))
258-
- \[CI\] Cancel previous runs that are not completed. ([#583](https://github.com/open-mmlab/mmclassification/pull/583))
259-
- \[CI\] Skip build CI if only configs or docs modification. ([#575](https://github.com/open-mmlab/mmclassification/pull/575))
276+
- [Reproduction] Reproduce RegNetX training accuracy. ([#587](https://github.com/open-mmlab/mmclassification/pull/587))
277+
- [Reproduction] Reproduce training results of T2T-ViT. ([#610](https://github.com/open-mmlab/mmclassification/pull/610))
278+
- [Enhance] Provide high-acc training settings of ResNet. ([#572](https://github.com/open-mmlab/mmclassification/pull/572))
279+
- [Enhance] Set a random seed when the user does not set a seed. ([#554](https://github.com/open-mmlab/mmclassification/pull/554))
280+
- [Enhance] Added `NumClassCheckHook` and unit tests. ([#559](https://github.com/open-mmlab/mmclassification/pull/559))
281+
- [Enhance] Enhance feature extraction function. ([#593](https://github.com/open-mmlab/mmclassification/pull/593))
282+
- [Enhance] Improve efficiency of precision, recall, f1_score and support. ([#595](https://github.com/open-mmlab/mmclassification/pull/595))
283+
- [Enhance] Improve accuracy calculation performance. ([#592](https://github.com/open-mmlab/mmclassification/pull/592))
284+
- [Refactor] Refactor `analysis_log.py`. ([#529](https://github.com/open-mmlab/mmclassification/pull/529))
285+
- [Refactor] Use new API of matplotlib to handle blocking input in visualization. ([#568](https://github.com/open-mmlab/mmclassification/pull/568))
286+
- [CI] Cancel previous runs that are not completed. ([#583](https://github.com/open-mmlab/mmclassification/pull/583))
287+
- [CI] Skip build CI if only configs or docs modification. ([#575](https://github.com/open-mmlab/mmclassification/pull/575))
260288

261289
### Bug Fixes
262290

docs/en/faq.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ and make sure you fill in all required information in the template.
1818
| MMClassification version | MMCV version |
1919
| :----------------------: | :--------------------: |
2020
| dev | mmcv>=1.7.0, \<1.9.0 |
21-
| 0.24.1 (master) | mmcv>=1.4.2, \<1.9.0 |
21+
| 0.25.0 (master) | mmcv>=1.4.2, \<1.9.0 |
22+
| 0.24.1 | mmcv>=1.4.2, \<1.9.0 |
2223
| 0.23.2 | mmcv>=1.4.2, \<1.7.0 |
2324
| 0.22.1 | mmcv>=1.4.2, \<1.6.0 |
2425
| 0.21.0 | mmcv>=1.4.2, \<=1.5.0 |

docs/zh_CN/faq.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
| MMClassification version | MMCV version |
1717
| :----------------------: | :--------------------: |
1818
| dev | mmcv>=1.7.0, \<1.9.0 |
19-
| 0.24.1 (master) | mmcv>=1.4.2, \<1.9.0 |
19+
| 0.25.0 (master) | mmcv>=1.4.2, \<1.9.0 |
20+
| 0.24.1 | mmcv>=1.4.2, \<1.9.0 |
2021
| 0.23.2 | mmcv>=1.4.2, \<1.7.0 |
2122
| 0.22.1 | mmcv>=1.4.2, \<1.6.0 |
2223
| 0.21.0 | mmcv>=1.4.2, \<=1.5.0 |

mmcls/apis/train.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@
1010

1111
from mmcls.core import DistEvalHook, DistOptimizerHook, EvalHook
1212
from mmcls.datasets import build_dataloader, build_dataset
13-
from mmcls.utils import (get_root_logger, wrap_distributed_model,
14-
wrap_non_distributed_model)
13+
from mmcls.utils import (auto_select_device, get_root_logger,
14+
wrap_distributed_model, wrap_non_distributed_model)
1515

1616

17-
def init_random_seed(seed=None, device='cuda'):
17+
def init_random_seed(seed=None, device=None):
1818
"""Initialize random seed.
1919
2020
If the seed is not set, the seed will be automatically randomized,
@@ -30,7 +30,8 @@ def init_random_seed(seed=None, device='cuda'):
3030
"""
3131
if seed is not None:
3232
return seed
33-
33+
if device is None:
34+
device = auto_select_device()
3435
# Make sure all ranks share the same random seed to prevent
3536
# some potential bugs. Please refer to
3637
# https://github.com/open-mmlab/mmdetection/issues/6339

mmcls/core/hook/wandblogger_hook.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import numpy as np
55
from mmcv.runner import HOOKS, BaseRunner
6-
from mmcv.runner.dist_utils import master_only
6+
from mmcv.runner.dist_utils import get_dist_info, master_only
77
from mmcv.runner.hooks.checkpoint import CheckpointHook
88
from mmcv.runner.hooks.evaluation import DistEvalHook, EvalHook
99
from mmcv.runner.hooks.logger.wandb import WandbLoggerHook
@@ -190,7 +190,6 @@ def after_train_epoch(self, runner):
190190
# Log the evaluation table
191191
self._log_eval_table(runner.epoch + 1)
192192

193-
@master_only
194193
def after_train_iter(self, runner):
195194
if self.get_mode(runner) == 'train':
196195
# An ugly patch. The iter-based eval hook will call the
@@ -201,6 +200,10 @@ def after_train_iter(self, runner):
201200
else:
202201
super(MMClsWandbHook, self).after_train_iter(runner)
203202

203+
rank, _ = get_dist_info()
204+
if rank != 0:
205+
return
206+
204207
if self.by_epoch:
205208
return
206209

0 commit comments

Comments
 (0)