Skip to content

Commit 964b47c

Browse files
Qingrennlvhan028
authored and committed
CodeCamp #101: Support MMDetection 3.x RTMDet model deployment on RV1126 (#1551)
* * partition rtmdet * * add rtmdet deploy config * * add rtmdet deploy config * * modify rtmdet pipline anchor_generator's info dump * support rtmdet infer in sdk * fix a bug * * fix a bug in csrc/mmdeploy/preprocess/transform/normalize.cpp * * fix a bug * * update docs * * fix lint * * update several urls in docs
1 parent 9d843cd commit 964b47c

File tree

13 files changed

+304
-7
lines changed

13 files changed

+304
-7
lines changed
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
_base_ = ['../_base_/base_static.py', '../../_base_/backends/rknn.py']
2+
3+
onnx_config = dict(input_shape=[640, 640])
4+
5+
codebase_config = dict(model_type='rknn')
6+
7+
backend_config = dict(input_size_list=[[3, 640, 640]])
8+
9+
# rtmdet for rknn-toolkit and rknn-toolkit2
10+
# partition_config = dict(
11+
# type='rknn', # the partition policy name
12+
# apply_marks=True, # should always be set to True
13+
# partition_cfg=[
14+
# dict(
15+
# save_file='model.onnx', # name to save the partitioned onnx
16+
# start=['detector_forward:input'], # [mark_name:input, ...]
17+
# end=['rtmdet_head:output'], # [mark_name:output, ...]
18+
# output_names=[f'pred_maps.{i}' for i in range(6)]) # output names
19+
# ])
Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
// Copyright (c) OpenMMLab. All rights reserved.
2+
#include "rtmdet_head.h"
3+
4+
#include <math.h>
5+
6+
#include <algorithm>
7+
#include <numeric>
8+
9+
#include "mmdeploy/core/model.h"
10+
#include "mmdeploy/core/utils/device_utils.h"
11+
#include "mmdeploy/core/utils/formatter.h"
12+
#include "utils.h"
13+
14+
namespace mmdeploy::mmdet {
15+
16+
// Builds the RTMDet post-processing component from its configuration.
//
// When `cfg` has a "params" entry, the score threshold, NMS settings, size
// limits and (if present) the anchor generator's offset and strides are read
// from it; keys that are absent fall back to the literals used below.
RTMDetSepBNHead::RTMDetSepBNHead(const Value& cfg) : MMDetection(cfg) {
  const auto load_params = [&]() -> Result<void> {
    // Fetched exactly as before; the handle is not used further in this function.
    auto model = cfg["context"]["model"].get<Model>();

    if (!cfg.contains("params")) {
      return success();
    }
    const auto& params = cfg["params"];

    nms_pre_ = params.value("nms_pre", -1);
    score_thr_ = params.value("score_thr", 0.02f);
    min_bbox_size_ = params.value("min_bbox_size", 0);
    max_per_img_ = params.value("max_per_img", 100);

    // The IoU threshold lives in the optional "nms" sub-config.
    if (params.contains("nms")) {
      iou_threshold_ = params["nms"].value("iou_threshold", 0.45f);
    } else {
      iou_threshold_ = 0.45f;
    }

    // Anchor offset and per-level strides come from the optional anchor generator config.
    if (params.contains("anchor_generator")) {
      const auto& anchor_cfg = params["anchor_generator"];
      offset_ = anchor_cfg.value("offset", 0);
      from_value(anchor_cfg["strides"], strides_);
    }
    return success();
  };
  load_params().value();
}
36+
37+
// Post-processes one inference result.
//
// The tensors in `infer_res` are made available on the host device; the first
// half (by position) are treated as classification score maps and the rest as
// bbox prediction maps. The decoded detections are returned as a Value. Any
// exception thrown on the way is reported as Status(eFail).
Result<Value> RTMDetSepBNHead::operator()(const Value& prep_res, const Value& infer_res) {
  MMDEPLOY_DEBUG("prep_res: {}\ninfer_res: {}", prep_res, infer_res);
  try {
    const Device host_device{0, 0};
    const int num_score_maps = infer_res.size() / 2;

    std::vector<Tensor> cls_scores;
    std::vector<Tensor> bbox_preds;

    int position = 0;
    for (auto it = infer_res.begin(); it != infer_res.end(); it++, position++) {
      auto raw_map = it->get<Tensor>();
      OUTCOME_TRY(auto host_map, MakeAvailableOnDevice(raw_map, host_device, stream()));
      // Outputs before the midpoint are score maps, the remainder are bbox maps.
      auto& destination = (position < num_score_maps) ? cls_scores : bbox_preds;
      destination.push_back(host_map);
    }

    // Wait on the stream before the tensor data is read on the host.
    OUTCOME_TRY(stream().Wait());
    OUTCOME_TRY(auto detections, GetBBoxes(prep_res["img_metas"], bbox_preds, cls_scores));
    return to_value(detections);
  } catch (...) {
    return Status(eFail);
  }
}
61+
62+
// Logistic function 1 / (1 + e^-x); the division is carried out in double
// precision and the result is narrowed to float on return.
static float sigmoid(float x) {
  const double denominator = 1.0 + expf(-x);
  return 1.0 / denominator;
}
63+
64+
// Converts the per-level head outputs into detections in original-image coordinates.
//
// prep_res   - image meta data; "ori_shape" is read as [_, height, width, ...] and
//              "scale_factor" is used when present (otherwise all ones).
// bbox_preds - one bbox prediction map per feature level.
// cls_scores - one classification score map per feature level, same order as bbox_preds.
//
// Returns the detections that survive score filtering, NMS and the minimum
// box-size check, or Status(eFail) when the number of score maps or configured
// strides does not cover every bbox map.
Result<Detections> RTMDetSepBNHead::GetBBoxes(const Value& prep_res,
                                              const std::vector<Tensor>& bbox_preds,
                                              const std::vector<Tensor>& cls_scores) const {
  // Each level is decoded from bbox_preds[i], cls_scores[i] and strides_[i]; reject
  // inconsistent inputs instead of indexing out of bounds.
  const size_t num_levels = bbox_preds.size();
  if (cls_scores.size() != num_levels || strides_.size() < num_levels) {
    MMDEPLOY_DEBUG("inconsistent head outputs: {} bbox maps, {} score maps, {} strides",
                   num_levels, cls_scores.size(), strides_.size());
    return Status(eFail);
  }
  if (num_levels > 0) {
    MMDEPLOY_DEBUG("bbox_pred: {}, {}", bbox_preds[0].shape(), bbox_preds[0].data_type());
    MMDEPLOY_DEBUG("cls_score: {}, {}", cls_scores[0].shape(), cls_scores[0].data_type());
  }

  // Candidates above the score threshold, accumulated over all levels.
  std::vector<float> filter_boxes;  // 4 floats (x1, y1, x2, y2) per candidate
  std::vector<float> obj_probs;
  std::vector<int> class_ids;
  for (size_t level = 0; level < num_levels; ++level) {
    RTMDetFeatDeocde(bbox_preds[level], cls_scores[level], strides_[level], offset_, filter_boxes,
                     obj_probs, class_ids);
  }

  // indexArray[i] is the box index that belongs to obj_probs[i] / class_ids[i];
  // Sort() reorders the three arrays together.
  std::vector<int> indexArray(obj_probs.size());
  std::iota(indexArray.begin(), indexArray.end(), 0);
  Sort(obj_probs, class_ids, indexArray);

  Tensor dets(TensorDesc{Device{0, 0}, DataType::kFLOAT,
                         TensorShape{static_cast<int>(filter_boxes.size() / 4), 4}, "dets"});
  std::copy(filter_boxes.begin(), filter_boxes.end(), dets.data<float>());
  NMS(dets, iou_threshold_, indexArray);

  std::vector<float> scale_factor;
  if (prep_res.contains("scale_factor")) {
    from_value(prep_res["scale_factor"], scale_factor);
  } else {
    scale_factor = {1.f, 1.f, 1.f, 1.f};
  }
  const int ori_width = prep_res["ori_shape"][2].get<int>();
  const int ori_height = prep_res["ori_shape"][1].get<int>();

  Detections objs;
  const auto det_ptr = dets.data<float>();
  const int num_candidates = static_cast<int>(indexArray.size());
  for (int i = 0; i < num_candidates; ++i) {
    // Entries set to -1 are skipped (presumably marked as suppressed by NMS — confirm in utils).
    const int j = indexArray[i];
    if (j == -1) {
      continue;
    }
    const auto x1 = det_ptr[j * 4 + 0];
    const auto y1 = det_ptr[j * 4 + 1];
    const auto x2 = det_ptr[j * 4 + 2];
    const auto y2 = det_ptr[j * 4 + 3];
    // Scores and labels are indexed by sorted position, boxes by original index.
    const int label_id = class_ids[i];
    const float score = obj_probs[i];

    MMDEPLOY_DEBUG("{}-th box: ({}, {}, {}, {}), {}, {}", i, x1, y1, x2, y2, label_id, score);

    auto rect =
        MapToOriginImage(x1, y1, x2, y2, scale_factor.data(), 0, 0, ori_width, ori_height, 0, 0);
    if (rect[2] - rect[0] < min_bbox_size_ || rect[3] - rect[1] < min_bbox_size_) {
      MMDEPLOY_DEBUG("ignore small bbox with width '{}' and height '{}'", rect[2] - rect[0],
                     rect[3] - rect[1]);
      continue;
    }
    Detection det{};
    det.index = i;
    det.label_id = label_id;
    det.score = score;
    det.bbox = rect;
    objs.push_back(std::move(det));
  }

  return objs;
}
131+
132+
int RTMDetSepBNHead::RTMDetFeatDeocde(const Tensor& bbox_pred, const Tensor& cls_score,
133+
const float stride, const float offset,
134+
std::vector<float>& filter_boxes,
135+
std::vector<float>& obj_probs,
136+
std::vector<int>& class_ids) const {
137+
int cls_param_num = cls_score.shape(1);
138+
int feat_h = bbox_pred.shape(2);
139+
int feat_w = bbox_pred.shape(3);
140+
int feat_size = feat_h * feat_w;
141+
auto bbox_ptr = bbox_pred.data<float>();
142+
auto score_ptr = cls_score.data<float>(); // (b, c, h, w)
143+
int valid_count = 0;
144+
for (int i = 0; i < feat_h; i++) {
145+
for (int j = 0; j < feat_w; j++) {
146+
float max_score = score_ptr[i * feat_w + j];
147+
int class_id = 0;
148+
for (int k = 0; k < cls_param_num; k++) {
149+
float score = score_ptr[k * feat_size + i * feat_w + j];
150+
if (score > max_score) {
151+
max_score = score;
152+
class_id = k;
153+
}
154+
}
155+
max_score = sigmoid(max_score);
156+
if (max_score < score_thr_) continue;
157+
158+
obj_probs.push_back(max_score);
159+
class_ids.push_back(class_id);
160+
161+
float tl_x = bbox_ptr[0 * feat_size + i * feat_w + j];
162+
float tl_y = bbox_ptr[1 * feat_size + i * feat_w + j];
163+
float br_x = bbox_ptr[2 * feat_size + i * feat_w + j];
164+
float br_y = bbox_ptr[3 * feat_size + i * feat_w + j];
165+
166+
auto box = RTMDetdecode(tl_x, tl_y, br_x, br_y, stride, offset, j, i);
167+
168+
tl_x = box[0];
169+
tl_y = box[1];
170+
br_x = box[2];
171+
br_y = box[3];
172+
173+
filter_boxes.push_back(tl_x);
174+
filter_boxes.push_back(tl_y);
175+
filter_boxes.push_back(br_x);
176+
filter_boxes.push_back(br_y);
177+
valid_count++;
178+
}
179+
}
180+
return valid_count;
181+
}
182+
183+
std::array<float, 4> RTMDetSepBNHead::RTMDetdecode(float tl_x, float tl_y, float br_x, float br_y,
184+
float stride, float offset, int j, int i) const {
185+
tl_x = (offset + j) * stride - tl_x;
186+
tl_y = (offset + i) * stride - tl_y;
187+
br_x = (offset + j) * stride + br_x;
188+
br_y = (offset + i) * stride + br_y;
189+
return std::array<float, 4>{tl_x, tl_y, br_x, br_y};
190+
}
191+
192+
MMDEPLOY_REGISTER_CODEBASE_COMPONENT(MMDetection, RTMDetSepBNHead);
193+
194+
} // namespace mmdeploy::mmdet
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// Copyright (c) OpenMMLab. All rights reserved.
2+
#ifndef MMDEPLOY_CODEBASE_MMDET_RTMDET_HEAD_H_
3+
#define MMDEPLOY_CODEBASE_MMDET_RTMDET_HEAD_H_
4+
5+
#include "mmdeploy/codebase/mmdet/mmdet.h"
6+
#include "mmdeploy/core/tensor.h"
7+
8+
namespace mmdeploy::mmdet {
9+
10+
class RTMDetSepBNHead : public MMDetection {
11+
public:
12+
explicit RTMDetSepBNHead(const Value& cfg);
13+
Result<Value> operator()(const Value& prep_res, const Value& infer_res);
14+
Result<Detections> GetBBoxes(const Value& prep_res, const std::vector<Tensor>& bbox_preds,
15+
const std::vector<Tensor>& cls_scores) const;
16+
int RTMDetFeatDeocde(const Tensor& bbox_pred, const Tensor& cls_score, const float stride,
17+
const float offset, std::vector<float>& filter_boxes,
18+
std::vector<float>& obj_probs, std::vector<int>& class_ids) const;
19+
std::array<float, 4> RTMDetdecode(float tl_x, float tl_y, float br_x, float br_y, float stride,
20+
float offset, int j, int i) const;
21+
22+
private:
23+
float score_thr_{0.4f};
24+
int nms_pre_{1000};
25+
float iou_threshold_{0.45f};
26+
int min_bbox_size_{0};
27+
int max_per_img_{100};
28+
float offset_{0.0f};
29+
std::vector<float> strides_;
30+
};
31+
32+
} // namespace mmdeploy::mmdet
33+
34+
#endif // MMDEPLOY_CODEBASE_MMDET_RTMDET_HEAD_H_

csrc/mmdeploy/preprocess/transform/normalize.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,13 +99,14 @@ class Normalize : public Transform {
9999
Tensor dst;
100100
if (to_float_) {
101101
OUTCOME_TRY(normalize_.Apply(tensor, dst));
102+
data[key] = std::move(dst);
102103
} else if (to_rgb_) {
103104
auto src_mat = to_mat(tensor, PixelFormat::kBGR);
104105
Mat dst_mat;
105106
OUTCOME_TRY(cvt_color_.Apply(src_mat, dst_mat, PixelFormat::kBGR));
106107
dst = to_tensor(src_mat);
108+
data[key] = std::move(dst);
107109
}
108-
data[key] = std::move(dst);
109110

110111
for (auto& v : mean_) {
111112
data["img_norm_cfg"]["mean"].push_back(v);

docs/en/01-how-to-build/rockchip.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,22 @@ label: 65, score: 0.95
156156
])
157157
```
158158

159+
RTMDet: you may paste the following partition configuration into [detection_rknn-int8_static-640x640.py](https://github.com/open-mmlab/mmdeploy/blob/dev-1.x/configs/mmdet/detection/detection_rknn-int8_static-640x640.py):
160+
161+
```python
162+
# rtmdet for rknn-toolkit and rknn-toolkit2
163+
partition_config = dict(
164+
type='rknn', # the partition policy name
165+
apply_marks=True, # should always be set to True
166+
partition_cfg=[
167+
dict(
168+
save_file='model.onnx', # name to save the partitioned onnx
169+
start=['detector_forward:input'], # [mark_name:input, ...]
170+
end=['rtmdet_head:output'], # [mark_name:output, ...]
171+
output_names=[f'pred_maps.{i}' for i in range(6)]) # output names
172+
])
173+
```
174+
159175
RetinaNet & SSD & FSAF with rknn-toolkit2, you may paste the following partition configuration into [detection_rknn_static-320x320.py](https://github.com/open-mmlab/mmdeploy/tree/1.x/configs/mmdet/detection/detection_rknn_static-320x320.py). Users with rknn-toolkit can directly use default config.
160176

161177
```python

docs/en/02-how-to-run/convert_model.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ Notes:
1111
### Prerequisite
1212

1313
1. Install and build your target backend. You could refer to [ONNXRuntime-install](../05-supported-backends/onnxruntime.md), [TensorRT-install](../05-supported-backends/tensorrt.md), [ncnn-install](../05-supported-backends/ncnn.md), [PPLNN-install](../05-supported-backends/pplnn.md), [OpenVINO-install](../05-supported-backends/openvino.md) for more information.
14-
2. Install and build your target codebase. You could refer to [MMClassification-install](https://github.com/open-mmlab/mmclassification/blob/1.x/docs/en/get_started.md#installation), [MMDetection-install](https://github.com/open-mmlab/mmdetection/blob/3.x/docs/en/get_started.md), [MMSegmentation-install](https://github.com/open-mmlab/mmsegmentation/blob/1.x/docs/en/get_started.md#installation), [MMOCR-install](https://github.com/open-mmlab/mmocr/blob/1.x/docs/en/get_started/install.md), [MMEditing-install](https://github.com/open-mmlab/mmediting/blob/1.x/docs/en/2_get_started.md#installation).
14+
2. Install and build your target codebase. You could refer to [MMClassification-install](https://github.com/open-mmlab/mmclassification/blob/1.x/docs/en/get_started.md#installation), [MMDetection-install](https://github.com/open-mmlab/mmdetection/blob/3.x/docs/en/get_started.md), [MMSegmentation-install](https://github.com/open-mmlab/mmsegmentation/blob/1.x/docs/en/get_started.md#installation), [MMOCR-install](https://github.com/open-mmlab/mmocr/blob/1.x/docs/en/get_started/install.md), [MMEditing-install](https://github.com/open-mmlab/mmediting/blob/1.x/docs/en/get_started/install.md).
1515

1616
### Usage
1717

docs/en/02-how-to-run/write_config.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,4 +177,4 @@ detection_tensorrt-int8_dynamic-320x320-1344x1344.py
177177

178178
## 6. How to write model config
179179

180-
According to model's codebase, write the model config file. Model's config file is used to initialize the model, referring to [MMClassification](https://github.com/open-mmlab/mmclassification/blob/1.x/docs/en/user_guides/config.md), [MMDetection](https://github.com/open-mmlab/mmdetection/blob/3.x/docs/en/user_guides/config.md), [MMSegmentation](https://github.com/open-mmlab/mmsegmentation/blob/1.x/docs/en/user_guides/1_config.md), [MMOCR](https://github.com/open-mmlab/mmocr/blob/1.x/docs/en/user_guides/config.md), [MMEditing](https://github.com/open-mmlab/mmediting/blob/1.x/docs/en/user_guides/1_config.md).
180+
According to model's codebase, write the model config file. Model's config file is used to initialize the model, referring to [MMClassification](https://github.com/open-mmlab/mmclassification/blob/1.x/docs/en/user_guides/config.md), [MMDetection](https://github.com/open-mmlab/mmdetection/blob/3.x/docs/en/user_guides/config.md), [MMSegmentation](https://github.com/open-mmlab/mmsegmentation/blob/1.x/docs/en/user_guides/1_config.md), [MMOCR](https://github.com/open-mmlab/mmocr/blob/1.x/docs/en/user_guides/config.md), [MMEditing](https://github.com/open-mmlab/mmediting/blob/1.x/docs/en/user_guides/config.md).

docs/zh_cn/01-how-to-build/rockchip.md

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,24 @@ python tools/deploy.py \
134134

135135
```
136136

137+
- RTMDet
138+
139+
将下面的模型拆分配置写入到 [detection_rknn-int8_static-640x640.py](https://github.com/open-mmlab/mmdeploy/blob/dev-1.x/configs/mmdet/detection/detection_rknn-int8_static-640x640.py)
140+
141+
```python
142+
# rtmdet for rknn-toolkit and rknn-toolkit2
143+
partition_config = dict(
144+
type='rknn', # the partition policy name
145+
apply_marks=True, # should always be set to True
146+
partition_cfg=[
147+
dict(
148+
save_file='model.onnx', # name to save the partitioned onnx
149+
start=['detector_forward:input'], # [mark_name:input, ...]
150+
end=['rtmdet_head:output'], # [mark_name:output, ...]
151+
output_names=[f'pred_maps.{i}' for i in range(6)]) # output names
152+
])
153+
```
154+
137155
- RetinaNet & SSD & FSAF with rknn-toolkit2
138156

139157
将下面的模型拆分配置写入到 [detection_rknn_static.py](https://github.com/open-mmlab/mmdeploy/blob/1.x/configs/mmdet/detection/detection_rknn_static-320x320.py)。使用 rknn-toolkit 的用户则不用。

docs/zh_cn/02-how-to-run/convert_model.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
### 准备工作
2727

2828
1. 安装您的目标后端。 您可以参考 [ONNXRuntime-install](../05-supported-backends/onnxruntime.md)[TensorRT-install](../05-supported-backends/tensorrt.md)[ncnn-install](../05-supported-backends/ncnn.md)[PPLNN-install](../05-supported-backends/pplnn.md), [OpenVINO-install](../05-supported-backends/openvino.md)
29-
2. 安装您的目标代码库。 您可以参考 [MMClassification-install](https://github.com/open-mmlab/mmclassification/blob/1.x/docs/zh_CN/get_started.md#%E5%AE%89%E8%A3%85)[MMDetection-install](https://github.com/open-mmlab/mmdetection/blob/3.x/docs/zh_cn/get_started.md)[MMSegmentation-install](https://github.com/open-mmlab/mmsegmentation/blob/1.x/docs/zh_cn/get_started.md#installation)[MMOCR-install](https://github.com/open-mmlab/mmocr/blob/1.x/docs/zh_cn/get_started/install.md)[MMEditing-install](https://github.com/open-mmlab/mmediting/blob/1.x/docs/en/2_get_started.md#installation)
29+
2. 安装您的目标代码库。 您可以参考 [MMClassification-install](https://github.com/open-mmlab/mmclassification/blob/1.x/docs/zh_CN/get_started.md#%E5%AE%89%E8%A3%85)[MMDetection-install](https://github.com/open-mmlab/mmdetection/blob/3.x/docs/zh_cn/get_started.md)[MMSegmentation-install](https://github.com/open-mmlab/mmsegmentation/blob/1.x/docs/zh_cn/get_started.md#installation)[MMOCR-install](https://github.com/open-mmlab/mmocr/blob/1.x/docs/zh_cn/get_started/install.md)[MMEditing-install](https://github.com/open-mmlab/mmediting/blob/1.x/docs/en/get_started/install.md)
3030

3131
### 使用方法
3232

docs/zh_cn/02-how-to-run/write_config.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,4 +187,4 @@ detection_tensorrt-int8_dynamic-320x320-1344x1344.py
187187

188188
## 6. 如何编写模型配置文件
189189

190-
请根据模型具体任务的代码库,编写模型配置文件。 模型配置文件用于初始化模型,详情请参考[MMClassification](https://github.com/open-mmlab/mmclassification/blob/1.x/docs/zh_CN/user_guides/config.md)[MMDetection](https://github.com/open-mmlab/mmdetection/blob/3.x/docs/zh_cn/user_guides/config.md)[MMSegmentation](https://github.com/open-mmlab/mmsegmentation/blob/1.x/docs/zh_cn/user_guides/1_config.md)[MMOCR](https://github.com/open-mmlab/mmocr/blob/1.x/docs/en/user_guides/config.md)[MMEditing](https://github.com/open-mmlab/mmediting/blob/1.x/docs/en/user_guides/1_config.md)
190+
请根据模型具体任务的代码库,编写模型配置文件。 模型配置文件用于初始化模型,详情请参考[MMClassification](https://github.com/open-mmlab/mmclassification/blob/1.x/docs/zh_CN/user_guides/config.md)[MMDetection](https://github.com/open-mmlab/mmdetection/blob/3.x/docs/zh_cn/user_guides/config.md)[MMSegmentation](https://github.com/open-mmlab/mmsegmentation/blob/1.x/docs/zh_cn/user_guides/1_config.md)[MMOCR](https://github.com/open-mmlab/mmocr/blob/1.x/docs/en/user_guides/config.md)[MMEditing](https://github.com/open-mmlab/mmediting/blob/1.x/docs/en/user_guides/config.md)

0 commit comments

Comments
 (0)