Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions configs/mmdet/_base_/base_openvino_dynamic-640x640.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Deploy config that layers the shared `base_dynamic.py` settings with the
# OpenVINO backend settings; the keys below override what those bases define.
_base_ = ['./base_dynamic.py', '../../_base_/backends/openvino.py']

# Leave the ONNX export without a fixed input shape.
onnx_config = dict(input_shape=None)

# Shape given to the backend for the model input named `input`.
# NOTE(review): [1, 3, 640, 640] is presumably (N, C, H, W) -- confirm.
backend_config = dict(
    model_inputs=[dict(opt_shapes=dict(input=[1, 3, 640, 640]))])
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Inherits everything from the shared OpenVINO dynamic 640x640 base config;
# no keys are overridden here.
_base_ = ['../_base_/base_openvino_dynamic-640x640.py']
2 changes: 2 additions & 0 deletions csrc/mmdeploy/codebase/mmcls/linear_cls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ class LinearClsHead : public MMClassification {
};

MMDEPLOY_REGISTER_CODEBASE_COMPONENT(MMClassification, LinearClsHead);
// ConformerHead is a plain type alias of LinearClsHead, registered a second
// time under the alias name.
// NOTE(review): presumably this lets configs that name `ConformerHead` reuse
// the LinearClsHead implementation -- confirm against the registry macro.
using ConformerHead = LinearClsHead;
MMDEPLOY_REGISTER_CODEBASE_COMPONENT(MMClassification, ConformerHead);

class CropBox {
public:
Expand Down
29 changes: 2 additions & 27 deletions mmdeploy/codebase/mmdet/models/dense_heads/detr_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,31 +8,6 @@
from mmdeploy.core import FUNCTION_REWRITER


@FUNCTION_REWRITER.register_rewriter(
    'mmdet.models.dense_heads.DETRHead.forward_single')
def detrhead__forward_single__default(self, x, img_metas):
    """Rewritten ``forward_single`` of DETRHead with a simplified mask.

    The mask is created as an all-zero boolean tensor whose spatial size
    matches the projected feature map.

    Args:
        x: Input feature map; ``x.size(0)`` is used as the batch size.
        img_metas: Not read by this rewrite; kept in the signature.

    Returns:
        tuple: ``(all_cls_scores, all_bbox_preds)`` produced from the
        transformer decoder outputs.
    """
    num_imgs = x.size(0)
    feats = self.input_proj(x)

    # boolean mask with the same spatial shape as the projected features
    mask_shape = (num_imgs, feats.size(-2), feats.size(-1))
    masks = feats.new_zeros(mask_shape).to(torch.bool)

    # position encoding: [bs, embed_dim, h, w]
    pos_embed = self.positional_encoding(masks)

    # decoder outputs: [nb_dec, bs, num_query, embed_dim]
    outs_dec, _ = self.transformer(
        feats, masks, self.query_embedding.weight, pos_embed)

    all_cls_scores = self.fc_cls(outs_dec)
    reg_hidden = self.activate(self.reg_ffn(outs_dec))
    all_bbox_preds = self.fc_reg(reg_hidden).sigmoid()
    return all_cls_scores, all_bbox_preds


@FUNCTION_REWRITER.register_rewriter(
'mmdet.models.dense_heads.DETRHead.predict_by_feat')
def detrhead__predict_by_feat__default(self,
Expand All @@ -42,8 +17,8 @@ def detrhead__predict_by_feat__default(self,
rescale: bool = True):
"""Rewrite `predict_by_feat` of `FoveaHead` for default backend."""
from mmdet.structures.bbox import bbox_cxcywh_to_xyxy
cls_scores = all_cls_scores_list[-1][-1]
bbox_preds = all_bbox_preds_list[-1][-1]
cls_scores = all_cls_scores_list[-1]
bbox_preds = all_bbox_preds_list[-1]

img_shape = batch_img_metas[0]['img_shape']
max_per_img = self.test_cfg.get('max_per_img', len(cls_scores[0]))
Expand Down
6 changes: 4 additions & 2 deletions mmdeploy/codebase/mmdet/models/detectors/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from . import single_stage, single_stage_instance_seg, two_stage
from . import base_detr, single_stage, single_stage_instance_seg, two_stage

__all__ = ['single_stage', 'single_stage_instance_seg', 'two_stage']
__all__ = [
'base_detr', 'single_stage', 'single_stage_instance_seg', 'two_stage'
]
89 changes: 89 additions & 0 deletions mmdeploy/codebase/mmdet/models/detectors/base_detr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy

import torch
from mmdet.models.detectors.base import ForwardResults
from mmdet.structures import DetDataSample
from mmdet.structures.det_data_sample import OptSampleList

from mmdeploy.core import FUNCTION_REWRITER, mark
from mmdeploy.utils import is_dynamic_shape


@mark('detr_predict', inputs=['input'], outputs=['dets', 'labels', 'masks'])
def __predict_impl(self, batch_inputs, data_samples, rescale):
    """Marked implementation behind the rewritten ``predict``.

    Kept as a separate function so that the rewrite of
    ``DetectionTransformer.predict`` can:

    1. Attach the ``detr_predict`` mark to the prediction step.
    2. Serve both dynamic and static export to onnx.

    Args:
        batch_inputs: Inputs passed to ``self.extract_feat``.
        data_samples: Data samples forwarded to the transformer and the
            bbox head.
        rescale: Forwarded unchanged to ``self.bbox_head.predict``.

    Returns:
        The value returned by ``self.bbox_head.predict``.
    """
    feats = self.extract_feat(batch_inputs)
    transformer_outputs = self.forward_transformer(feats, data_samples)
    return self.bbox_head.predict(
        **transformer_outputs,
        rescale=rescale,
        batch_data_samples=data_samples)


@torch.fx.wrap
def _set_metainfo(data_samples, img_shape):
    """Return data samples whose ``img_shape`` metainfo is ``img_shape``.

    This code cannot be traced by fx, hence the ``torch.fx.wrap`` decorator.

    Args:
        data_samples: Samples to copy and annotate, or ``None`` to start
            from a single fresh ``DetDataSample``.
        img_shape: Value stored in the ``img_shape`` metainfo field.

    Returns:
        list: The annotated samples; the caller's objects are not modified.
    """
    if data_samples is None:
        samples = [DetDataSample()]
    else:
        # fx can not trace deepcopy correctly
        samples = copy.deepcopy(data_samples)

    # `set_metainfo` is avoided on purpose: its deepcopy would crash the
    # onnx trace, so the field is written directly.
    for sample in samples:
        sample.set_field(
            name='img_shape', value=img_shape, field_type='metainfo')

    return samples


@FUNCTION_REWRITER.register_rewriter(
    'mmdet.models.detectors.base_detr.DetectionTransformer.predict')
def detection_transformer__predict(self,
                                   batch_inputs: torch.Tensor,
                                   data_samples: OptSampleList = None,
                                   rescale: bool = True,
                                   **kwargs) -> ForwardResults:
    """Rewrite of ``DetectionTransformer.predict`` for the default backend.

    Honours the dynamic/static input shape setting of the deploy config and
    returns detection results as tensors rather than numpy arrays.

    Args:
        batch_inputs (Tensor): Inputs with shape (N, C, H, W).
        data_samples (List[:obj:`DetDataSample`]): The Data
            Samples. It usually includes information such as
            `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
        rescale (Boolean): rescale result or not.

    Returns:
        tuple[Tensor]: Detection results of the input images.

        - dets (Tensor): Classification bboxes and scores, with shape
          (num_instances, 5).
        - labels (Tensor): Labels of bboxes, with shape (num_instances, ).
    """
    rewriter_ctx = FUNCTION_REWRITER.get_context()
    dynamic_export = is_dynamic_shape(rewriter_ctx.cfg)

    # Read the spatial size as a tensor so that onnx dynamic shape works;
    # for static export it is converted to plain Python ints.
    img_shape = torch._shape_as_tensor(batch_inputs)[2:]
    if not dynamic_export:
        img_shape = [int(dim) for dim in img_shape]

    # Record the shape in the metainfo of the (copied) data samples.
    samples_with_meta = _set_metainfo(data_samples, img_shape)

    return __predict_impl(self, batch_inputs, samples_with_meta, rescale)
2 changes: 1 addition & 1 deletion mmdeploy/pytorch/functions/interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def interpolate__tensorrt(
size: Optional[Union[int, Tuple[int], Tuple[int, int], Tuple[int, int,
int]]] = None,
scale_factor: Optional[Union[float, Tuple[float]]] = None,
mode: str = 'bilinear',
mode: str = 'nearest',
align_corners: Optional[bool] = None,
recompute_scale_factor: Optional[bool] = None,
):
Expand Down
11 changes: 8 additions & 3 deletions tests/regression/mmdet.yml
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,9 @@ models:
- *pipeline_ort_dynamic_fp32
- *pipeline_trt_dynamic_fp32
- *pipeline_ncnn_static_fp32
- *pipeline_openvino_dynamic_fp32
- deploy_config: configs/mmdet/detection/detection_openvino_dynamic-640x640.py
convert_image: *convert_image
backend_test: False

- name: Faster R-CNN
metafile: configs/faster_rcnn/metafile.yml
Expand Down Expand Up @@ -298,7 +300,10 @@ models:
- configs/detr/detr_r50_8xb2-150e_coco.py
pipelines:
- *pipeline_ort_dynamic_fp32
- *pipeline_trt_dynamic_fp16
- deploy_config: configs/mmdet/detection/detection_tensorrt-fp16_dynamic-64x64-800x800.py
convert_image: *convert_image
backend_test: *default_backend_test
sdk_config: *sdk_dynamic

- name: CenterNet
metafile: configs/centernet/metafile.yml
Expand Down Expand Up @@ -335,7 +340,7 @@ models:
- configs/rtmdet/rtmdet_s_8xb32-300e_coco.py
pipelines:
- *pipeline_ort_dynamic_fp32
- deploy_config: configs/mmdet/detection/detection_tensorrt_static-640x640.py
- deploy_config: configs/mmdet/detection/detection_tensorrt_dynamic-64x64-800x800.py
convert_image: *convert_image
backend_test: *default_backend_test
sdk_config: *sdk_dynamic
Expand Down
8 changes: 8 additions & 0 deletions tests/test_codebase/test_mmcls/test_mmcls_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@ def get_invertedresidual_model():
return model


def get_fcuup_model():
    """Build an ``FCUUp(16, 16, 16)`` module with gradients disabled."""
    from mmcls.models.backbones.conformer import FCUUp

    fcu_up = FCUUp(16, 16, 16)
    fcu_up.requires_grad_(False)
    return fcu_up


def get_vit_backbone():
from mmcls.models.classifiers.image import ImageClassifier
model = ImageClassifier(
Expand Down
129 changes: 129 additions & 0 deletions tests/test_codebase/test_mmdet/data/detr_model.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
{
"type": "DETR",
"num_queries": 100,
"data_preprocessor": {
"type": "DetDataPreprocessor",
"mean": [123.675, 116.28, 103.53],
"std": [58.395, 57.12, 57.375],
"bgr_to_rgb": true,
"pad_size_divisor": 1
},
"backbone": {
"type": "ResNet",
"depth": 50,
"num_stages": 4,
"out_indices": [3],
"frozen_stages": 1,
"norm_cfg": {
"type": "BN",
"requires_grad": false
},
"norm_eval": true,
"style": "pytorch",
"init_cfg": {
"type": "Pretrained",
"checkpoint": "torchvision://resnet50"
}
},
"neck": {
"type": "ChannelMapper",
"in_channels": [2048],
"kernel_size": 1,
"out_channels": 256,
"num_outs": 1
},
"encoder": {
"num_layers": 6,
"layer_cfg": {
"self_attn_cfg": {
"embed_dims": 256,
"num_heads": 8,
"dropout": 0.1,
"batch_first": true
},
"ffn_cfg": {
"embed_dims": 256,
"feedforward_channels": 2048,
"num_fcs": 2,
"ffn_drop": 0.1,
"act_cfg": {
"type": "ReLU",
"inplace": true
}
}
}
},
"decoder": {
"num_layers": 6,
"layer_cfg": {
"self_attn_cfg": {
"embed_dims": 256,
"num_heads": 8,
"dropout": 0.1,
"batch_first": true
},
"cross_attn_cfg": {
"embed_dims": 256,
"num_heads": 8,
"dropout": 0.1,
"batch_first": true
},
"ffn_cfg": {
"embed_dims": 256,
"feedforward_channels": 2048,
"num_fcs": 2,
"ffn_drop": 0.1,
"act_cfg": {
"type": "ReLU",
"inplace": true
}
}
},
"return_intermediate": true
},
"positional_encoding": {
"num_feats": 128,
"normalize": true
},
"bbox_head": {
"type": "DETRHead",
"num_classes": 80,
"embed_dims": 256,
"loss_cls": {
"type": "CrossEntropyLoss",
"bg_cls_weight": 0.1,
"use_sigmoid": false,
"loss_weight": 1.0,
"class_weight": 1.0
},
"loss_bbox": {
"type": "L1Loss",
"loss_weight": 5.0
},
"loss_iou": {
"type": "GIoULoss",
"loss_weight": 2.0
}
},
"train_cfg": {
"assigner": {
"type":
"HungarianAssigner",
"match_costs": [{
"type": "ClassificationCost",
"weight": 1.0
}, {
"type": "BBoxL1Cost",
"weight": 5.0,
"box_format": "xywh"
}, {
"type": "IoUCost",
"iou_mode": "giou",
"weight": 2.0
}]
}
},
"test_cfg": {
"max_per_img": 100
}
}
Loading