Skip to content

Commit 95f6078

Browse files
committed
[Feature] Support YOLOv7 inference (open-mmlab#149)
* update * update * update * update * update * add docstr * fix comments * update
1 parent 3ea6edf commit 95f6078

File tree

11 files changed

+1111
-18
lines changed

11 files changed

+1111
-18
lines changed
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
_base_ = '../_base_/default_runtime.py'
2+
3+
# dataset settings
4+
data_root = 'data/coco/'
5+
dataset_type = 'YOLOv5CocoDataset'
6+
7+
# parameters that often need to be modified
8+
img_scale = (640, 640) # height, width
9+
deepen_factor = 1.0
10+
widen_factor = 1.0
11+
max_epochs = 300
12+
save_epoch_intervals = 10
13+
train_batch_size_per_gpu = 16
14+
train_num_workers = 8
15+
val_batch_size_per_gpu = 1
16+
val_num_workers = 2
17+
18+
# persistent_workers must be False if num_workers is 0.
19+
persistent_workers = True
20+
21+
# Batch shape policy; only used by the validation (and test) dataloader
22+
batch_shapes_cfg = dict(
23+
type='BatchShapePolicy',
24+
batch_size=val_batch_size_per_gpu,
25+
img_size=img_scale[0],
26+
size_divisor=32,
27+
extra_pad_ratio=0.5)
28+
29+
# Note: these anchors are different from the ones used by YOLOv5
30+
anchors = [[(12, 16), (19, 36), (40, 28)], [(36, 75), (76, 55), (72, 146)],
31+
[(142, 110), (192, 243), (459, 401)]]
32+
strides = [8, 16, 32]
33+
34+
# For single-scale training, enabling cudnn_benchmark is recommended,
35+
# since it can speed up training.
36+
env_cfg = dict(cudnn_benchmark=True)
37+
38+
model = dict(
39+
type='YOLODetector',
40+
data_preprocessor=dict(
41+
type='YOLOv5DetDataPreprocessor',
42+
mean=[0., 0., 0.],
43+
std=[255., 255., 255.],
44+
bgr_to_rgb=True),
45+
backbone=dict(
46+
type='YOLOv7Backbone',
47+
deepen_factor=deepen_factor,
48+
widen_factor=widen_factor,
49+
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
50+
act_cfg=dict(type='SiLU', inplace=True)),
51+
neck=dict(
52+
type='YOLOv7PAFPN',
53+
deepen_factor=deepen_factor,
54+
widen_factor=widen_factor,
55+
upsample_feats_cat_first=False,
56+
in_channels=[512, 1024, 1024],
57+
out_channels=[128, 256, 512],
58+
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
59+
act_cfg=dict(type='SiLU', inplace=True)),
60+
bbox_head=dict(
61+
type='YOLOv7Head',
62+
head_module=dict(
63+
type='YOLOv5HeadModule',
64+
num_classes=80,
65+
in_channels=[256, 512, 1024],
66+
widen_factor=widen_factor,
67+
featmap_strides=strides,
68+
num_base_priors=3),
69+
prior_generator=dict(
70+
type='mmdet.YOLOAnchorGenerator',
71+
base_sizes=anchors,
72+
strides=strides)),
73+
test_cfg=dict(
74+
multi_label=True,
75+
nms_pre=30000,
76+
score_thr=0.001,
77+
nms=dict(type='nms', iou_threshold=0.65),
78+
max_per_img=300))
79+
80+
test_pipeline = [
81+
dict(
82+
type='LoadImageFromFile',
83+
file_client_args={{_base_.file_client_args}}),
84+
dict(type='YOLOv5KeepRatioResize', scale=img_scale),
85+
dict(
86+
type='LetterResize',
87+
scale=img_scale,
88+
allow_scale_up=False,
89+
pad_val=dict(img=114)),
90+
dict(type='LoadAnnotations', with_bbox=True, _scope_='mmdet'),
91+
dict(
92+
type='mmdet.PackDetInputs',
93+
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
94+
'scale_factor', 'pad_param'))
95+
]
96+
97+
val_dataloader = dict(
98+
batch_size=val_batch_size_per_gpu,
99+
num_workers=val_num_workers,
100+
persistent_workers=persistent_workers,
101+
pin_memory=True,
102+
drop_last=False,
103+
sampler=dict(type='DefaultSampler', shuffle=False),
104+
dataset=dict(
105+
type=dataset_type,
106+
data_root=data_root,
107+
test_mode=True,
108+
data_prefix=dict(img='val2017/'),
109+
ann_file='annotations/instances_val2017.json',
110+
pipeline=test_pipeline,
111+
batch_shapes_cfg=batch_shapes_cfg))
112+
113+
test_dataloader = val_dataloader
114+
115+
val_evaluator = dict(
116+
type='mmdet.CocoMetric',
117+
proposal_nums=(100, 1, 10), # Can be accelerated
118+
ann_file=data_root + 'annotations/instances_val2017.json',
119+
metric='bbox')
120+
test_evaluator = val_evaluator
121+
122+
# train_cfg = dict(
123+
# type='EpochBasedTrainLoop',
124+
# max_epochs=max_epochs,
125+
# val_interval=save_epoch_intervals)
126+
val_cfg = dict(type='ValLoop')
127+
test_cfg = dict(type='TestLoop')
128+
129+
# randomness = dict(seed=1, deterministic=True)

mmyolo/models/backbones/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
from .csp_darknet import YOLOv5CSPDarknet, YOLOXCSPDarknet
44
from .cspnext import CSPNeXt
55
from .efficient_rep import YOLOv6EfficientRep
6+
from .yolov7_backbone import YOLOv7Backbone
67

78
__all__ = [
89
'YOLOv5CSPDarknet', 'BaseBackbone', 'YOLOv6EfficientRep',
9-
'YOLOXCSPDarknet', 'CSPNeXt'
10+
'YOLOXCSPDarknet', 'CSPNeXt', 'YOLOv7Backbone'
1011
]
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
from typing import List, Tuple, Union
3+
4+
import torch.nn as nn
5+
from mmcv.cnn import ConvModule
6+
from mmdet.utils import ConfigType, OptMultiConfig
7+
8+
from mmyolo.registry import MODELS
9+
from ..layers import ELANBlock, MaxPoolAndStrideConvBlock
10+
from .base_backbone import BaseBackbone
11+
12+
13+
@MODELS.register_module()
class YOLOv7Backbone(BaseBackbone):
    """Backbone used in YOLOv7.

    The network is a three-convolution stem followed by one stage per entry
    of ``arch_settings[arch]``. Every stage consists of a downsampling layer
    and an ``ELANBlock``.

    Args:
        arch (str): Architecture of YOLOv7. Only ``'P5'`` is currently
            defined in ``arch_settings``. Defaults to 'P5'.
        plugins (list[dict]): List of plugins for stages, each dict contains:

            - cfg (dict, required): Cfg dict to build plugin.
            - stages (tuple[bool], optional): Stages to apply plugin, length
              should be same as 'num_stages'.
        deepen_factor (float): Depth multiplier, forwarded to the base
            class. The stages built by this class always use two ELAN
            blocks and do not read it. Defaults to 1.0.
        widen_factor (float): Width multiplier, multiply number of
            channels in each layer by this amount. Defaults to 1.0.
        input_channels (int): Number of channels of the input image.
            Defaults to 3.
        out_indices (Sequence[int]): Output from which stages.
            Defaults to (2, 3, 4).
        frozen_stages (int): Stages to be frozen (stop grad and set eval
            mode). -1 means not freezing any parameters. Defaults to -1.
        norm_cfg (:obj:`ConfigDict` or dict): Dictionary to construct and
            config norm layer.
            Defaults to dict(type='BN', momentum=0.03, eps=0.001).
        act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer.
            Defaults to dict(type='SiLU', inplace=True).
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Defaults to False.
        init_cfg (:obj:`ConfigDict` or dict or list[dict] or
            list[:obj:`ConfigDict`]): Initialization config dict.
            Defaults to None.
    """

    # From left to right:
    # in_channels, out_channels, ELAN mode
    arch_settings = {
        'P5': [[64, 128, 'expand_channel_2x'], [256, 512, 'expand_channel_2x'],
               [512, 1024, 'expand_channel_2x'],
               [1024, 1024, 'no_change_channel']]
    }

    def __init__(self,
                 arch: str = 'P5',
                 plugins: Union[dict, List[dict]] = None,
                 deepen_factor: float = 1.0,
                 widen_factor: float = 1.0,
                 input_channels: int = 3,
                 out_indices: Tuple[int] = (2, 3, 4),
                 frozen_stages: int = -1,
                 norm_cfg: ConfigType = dict(
                     type='BN', momentum=0.03, eps=0.001),
                 act_cfg: ConfigType = dict(type='SiLU', inplace=True),
                 norm_eval: bool = False,
                 init_cfg: OptMultiConfig = None):
        super().__init__(
            self.arch_settings[arch],
            deepen_factor,
            widen_factor,
            input_channels=input_channels,
            out_indices=out_indices,
            plugins=plugins,
            frozen_stages=frozen_stages,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            norm_eval=norm_eval,
            init_cfg=init_cfg)

    def build_stem_layer(self) -> nn.Module:
        """Build a stem layer.

        The stem is three 3x3 ``ConvModule`` layers. Only the second one
        uses stride 2, so the stem downsamples the input by a factor of 2.

        Returns:
            nn.Module: The stem as an ``nn.Sequential``.
        """
        # Width of the stem output (= input width of the first stage) and
        # of the narrower first convolution.
        stem_channels = int(self.arch_setting[0][0] * self.widen_factor)
        half_channels = int(self.arch_setting[0][0] * self.widen_factor // 2)

        stem = nn.Sequential(
            ConvModule(
                # Honour the ``input_channels`` constructor argument instead
                # of a hard-coded 3.
                # NOTE(review): relies on BaseBackbone storing
                # ``self.input_channels`` before it calls this method.
                self.input_channels,
                half_channels,
                3,
                padding=1,
                stride=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg),
            ConvModule(
                half_channels,
                stem_channels,
                3,
                padding=1,
                stride=2,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg),
            ConvModule(
                stem_channels,
                stem_channels,
                3,
                padding=1,
                stride=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg))
        return stem

    def build_stage_layer(self, stage_idx: int, setting: list) -> list:
        """Build a stage layer.

        Args:
            stage_idx (int): The index of a stage layer.
            setting (list): The architecture setting of a stage layer, i.e.
                ``[in_channels, out_channels, elan_mode]``.

        Returns:
            list: ``[pre_layer, elan_layer]``, the downsampling layer
            followed by the ELAN block of this stage.
        """
        in_channels, out_channels, elan_mode = setting

        in_channels = int(in_channels * self.widen_factor)
        out_channels = int(out_channels * self.widen_factor)

        if stage_idx == 0:
            # The first stage downsamples with a plain stride-2 convolution
            # that already maps to ``out_channels``, so the ELAN block is
            # fed ``out_channels``.
            pre_layer = ConvModule(
                in_channels,
                out_channels,
                3,
                stride=2,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
            elan_in_channels = out_channels
        else:
            # Later stages downsample with MaxPoolAndStrideConvBlock; the
            # ELAN block that follows is built with ``in_channels``.
            pre_layer = MaxPoolAndStrideConvBlock(
                in_channels,
                mode='reduce_channel_2x',
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
            elan_in_channels = in_channels

        elan_layer = ELANBlock(
            elan_in_channels,
            mode=elan_mode,
            num_blocks=2,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        return [pre_layer, elan_layer]

mmyolo/models/dense_heads/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
from .rtmdet_head import RTMDetHead, RTMDetSepBNHeadModule
33
from .yolov5_head import YOLOv5Head, YOLOv5HeadModule
44
from .yolov6_head import YOLOv6Head, YOLOv6HeadModule
5+
from .yolov7_head import YOLOv7Head
56
from .yolox_head import YOLOXHead, YOLOXHeadModule
67

78
__all__ = [
89
'YOLOv5Head', 'YOLOv6Head', 'YOLOXHead', 'YOLOv5HeadModule',
910
'YOLOv6HeadModule', 'YOLOXHeadModule', 'RTMDetHead',
10-
'RTMDetSepBNHeadModule'
11+
'RTMDetSepBNHeadModule', 'YOLOv7Head'
1112
]

0 commit comments

Comments
 (0)