Skip to content

Commit c7003bb

Browse files
authored
[Fix] Fix CascadeRoIHead export when reg_class_agnostic=True in box_head (#1900)
* fix convnext * fix batch inference * update docs * add regression test config * fix pose_tracker.cpp lint
1 parent d181311 commit c7003bb

File tree

4 files changed

+31
-22
lines changed

4 files changed

+31
-22
lines changed

csrc/mmdeploy/apis/python/pose_tracker.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ std::vector<py::tuple> Apply(mmdeploy::PoseTracker* self,
3030
std::vector<py::tuple> batch_ret;
3131
batch_ret.reserve(frames.size());
3232
for (const auto& rs : results) {
33-
py::array_t<float> keypoints({static_cast<int>(rs.size()), rs.size() > 0 ? rs[0].keypoint_count : 0, 3});
33+
py::array_t<float> keypoints(
34+
{static_cast<int>(rs.size()), rs.size() > 0 ? rs[0].keypoint_count : 0, 3});
3435
py::array_t<float> bboxes({static_cast<int>(rs.size()), 4});
3536
py::array_t<uint32_t> track_ids(static_cast<int>(rs.size()));
3637
auto kpts_ptr = keypoints.mutable_data();

docs/en/03-benchmark/supported_models.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ The table below lists the models that are guaranteed to be exportable to other b
1717
| GFL | MMDetection | N | Y | Y | N | ? | Y | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/gfl) |
1818
| Cascade R-CNN | MMDetection | N | Y | Y | N | Y | Y | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/cascade_rcnn) |
1919
| Cascade Mask R-CNN | MMDetection | N | Y | Y | N | N | Y | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/cascade_rcnn) |
20+
| ConvNeXt | MMDetection | N | Y | Y | N | N | Y | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/convnext) |
2021
| Swin Transformer[\*](#note) | MMDetection | N | Y | Y | N | N | N | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/swin) |
2122
| VFNet | MMDetection | N | N | N | N | N | Y | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/vfnet) |
2223
| RepPoints | MMDetection | N | N | Y | N | ? | Y | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/reppoints) |

mmdeploy/codebase/mmdet/models/roi_heads/cascade_roi_head.py

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,18 @@ def cascade_roi_head__simple_test(ctx, self, x, proposals, img_metas,
4242
'while in exporting to ONNX'
4343
# Remove the scores
4444
rois = proposals[..., :-1]
45-
batch_size = rois.shape[0]
4645
num_proposals_per_img = rois.shape[1]
46+
batch_size = rois.shape[0]
4747
# Eliminate the batch dimension
4848
rois = rois.view(-1, 4)
49+
inds = torch.arange(
50+
batch_size, device=rois.device).float().repeat(num_proposals_per_img,
51+
1)
52+
inds = inds.t().reshape(-1, 1)
53+
rois = torch.cat([inds, rois], dim=1)
4954

50-
# Add dummy batch index
51-
rois = torch.cat([rois.new_zeros(rois.shape[0], 1), rois], dim=-1)
52-
53-
max_shape = img_metas[0]['img_shape']
55+
max_shape = None
56+
scale_factor = None
5457
ms_scores = []
5558
rcnn_test_cfg = self.test_cfg
5659

@@ -59,24 +62,19 @@ def cascade_roi_head__simple_test(ctx, self, x, proposals, img_metas,
5962

6063
cls_score = bbox_results['cls_score']
6164
bbox_pred = bbox_results['bbox_pred']
62-
# Recover the batch dimension
63-
rois = rois.reshape(batch_size, num_proposals_per_img, rois.size(-1))
64-
cls_score = cls_score.reshape(batch_size, num_proposals_per_img,
65-
cls_score.size(-1))
66-
bbox_pred = bbox_pred.reshape(batch_size, num_proposals_per_img, 4)
65+
6766
ms_scores.append(cls_score)
6867
if i < self.num_stages - 1:
69-
assert self.bbox_head[i].reg_class_agnostic
70-
new_rois = self.bbox_head[i].bbox_coder.decode(
71-
rois[..., 1:], bbox_pred, max_shape=max_shape)
72-
rois = new_rois.reshape(-1, new_rois.shape[-1])
73-
# Add dummy batch index
74-
rois = torch.cat([rois.new_zeros(rois.shape[0], 1), rois], dim=-1)
68+
assert not self.bbox_head[i].custom_activation
69+
bbox_label = cls_score[:, :-1].argmax(dim=1)
70+
rois = self.bbox_head[i].regress_by_class(rois, bbox_label,
71+
bbox_pred, img_metas[0])
7572

7673
cls_score = sum(ms_scores) / float(len(ms_scores))
77-
bbox_pred = bbox_pred.reshape(batch_size, num_proposals_per_img, 4)
78-
rois = rois.reshape(batch_size, num_proposals_per_img, -1)
79-
scale_factor = img_metas[0].get('scale_factor', None)
74+
cls_score = cls_score.reshape(batch_size, -1, cls_score.size(-1))
75+
rois = rois.reshape(batch_size, -1, rois.size(-1))
76+
bbox_pred = bbox_pred.reshape(batch_size, -1, bbox_pred.size(-1))
77+
8078
det_bboxes, det_labels = self.bbox_head[-1].get_bboxes(
8179
rois, cls_score, bbox_pred, max_shape, scale_factor, cfg=rcnn_test_cfg)
8280

@@ -85,8 +83,8 @@ def cascade_roi_head__simple_test(ctx, self, x, proposals, img_metas,
8583
else:
8684
batch_index = torch.arange(det_bboxes.size(0),
8785
device=det_bboxes.device). \
88-
float().view(-1, 1, 1).expand(
89-
det_bboxes.size(0), det_bboxes.size(1), 1)
86+
float().view(-1, 1, 1).expand(
87+
det_bboxes.size(0), det_bboxes.size(1), 1)
9088
rois = det_bboxes[..., :4]
9189
mask_rois = torch.cat([batch_index, rois], dim=-1)
9290
mask_rois = mask_rois.view(-1, 5)

tests/regression/mmdet.yml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,3 +320,12 @@ models:
320320
pipelines:
321321
- *pipeline_seg_ort_dynamic_fp32
322322
- *pipeline_seg_trt_dynamic_fp32
323+
324+
- name: Convnext
325+
metafile: configs/convnext/metafile.yml
326+
model_configs:
327+
- configs/convnext/mask_rcnn_convnext-t_p4_w7_fpn_fp16_ms-crop_3x_coco.py
328+
pipelines:
329+
- *pipeline_seg_ort_dynamic_fp32
330+
- *pipeline_seg_trt_dynamic_fp32
331+
- *pipeline_seg_openvino_dynamic_fp32

0 commit comments

Comments
 (0)