From 5db625ff553fabcc86764688baf3a404aa55c5a1 Mon Sep 17 00:00:00 2001 From: RunningLeon Date: Fri, 1 Sep 2023 18:03:04 +0800 Subject: [PATCH 1/7] detr batch infer --- .../detection_detr_onnxruntime_dynamic.py | 9 +++ ...detr_tensorrt_dynamic-320x320-1344x1344.py | 22 ++++++ .../codebase/mmdet/deploy/object_detection.py | 11 ++- .../mmdet/deploy/object_detection_model.py | 21 +++++- .../mmdet/models/dense_heads/detr_head.py | 26 +++---- .../mmdet/models/detectors/base_detr.py | 67 +++++++++++++++++-- mmdeploy/utils/config_utils.py | 2 +- tests/regression/mmdet.yml | 21 ++++-- 8 files changed, 152 insertions(+), 27 deletions(-) create mode 100644 configs/mmdet/detection/detection_detr_onnxruntime_dynamic.py create mode 100644 configs/mmdet/detection/detection_detr_tensorrt_dynamic-320x320-1344x1344.py diff --git a/configs/mmdet/detection/detection_detr_onnxruntime_dynamic.py b/configs/mmdet/detection/detection_detr_onnxruntime_dynamic.py new file mode 100644 index 0000000000..aee8990082 --- /dev/null +++ b/configs/mmdet/detection/detection_detr_onnxruntime_dynamic.py @@ -0,0 +1,9 @@ +_base_ = ['../_base_/base_dynamic.py', '../../_base_/backends/onnxruntime.py'] +onnx_config = dict( + input_names=['input', 'shape'], + dynamic_axes={ + 'shape': { + 0: 'batch' + }, + }, +) diff --git a/configs/mmdet/detection/detection_detr_tensorrt_dynamic-320x320-1344x1344.py b/configs/mmdet/detection/detection_detr_tensorrt_dynamic-320x320-1344x1344.py new file mode 100644 index 0000000000..bf023726ab --- /dev/null +++ b/configs/mmdet/detection/detection_detr_tensorrt_dynamic-320x320-1344x1344.py @@ -0,0 +1,22 @@ +_base_ = ['../_base_/base_dynamic.py', '../../_base_/backends/tensorrt.py'] +onnx_config = dict( + input_names=['input', 'shape'], + dynamic_axes={ + 'shape': { + 0: 'batch', + }, + }, +) + +backend_config = dict( + common_config=dict(max_workspace_size=1 << 30), + model_inputs=[ + dict( + input_shapes=dict( + shape=dict( + min_shape=[1, 2], opt_shape=[1, 2], max_shape=[2, 2]), + input=dict( + min_shape=[1, 3, 320, 320], + opt_shape=[1, 3, 800, 1344], + max_shape=[2, 3, 1344, 1344]))) + ]) diff --git a/mmdeploy/codebase/mmdet/deploy/object_detection.py b/mmdeploy/codebase/mmdet/deploy/object_detection.py index 92152d73cf..1e619273e0 100644 --- a/mmdeploy/codebase/mmdet/deploy/object_detection.py +++ b/mmdeploy/codebase/mmdet/deploy/object_detection.py @@ -12,7 +12,7 @@ from mmdeploy.codebase.base import CODEBASE, BaseTask, MMCodebase from mmdeploy.utils import Backend, Codebase, Task from mmdeploy.utils.config_utils import (get_backend, get_input_shape, - is_dynamic_shape) + is_dynamic_batch, is_dynamic_shape) MMDET_TASK = Registry('mmdet_tasks') @@ -218,7 +218,14 @@ def create_input( data = pseudo_collate(data) if data_preprocessor is not None: data = data_preprocessor(data, False) - return data, data['inputs'] + model_type = self.model_cfg.model.type + inputs = data['inputs'] + if dynamic_flag and is_dynamic_batch( + self.deploy_cfg) and model_type in ['DETR']: + shape_info = torch._shape_as_tensor(inputs)[2:].unsqueeze( + 0).to(torch.long).to(inputs.device) + inputs = (inputs, shape_info) + return data, inputs else: return data, BaseTask.get_tensor_from_input(data) diff --git a/mmdeploy/codebase/mmdet/deploy/object_detection_model.py b/mmdeploy/codebase/mmdet/deploy/object_detection_model.py index a1e52bffd3..9256b25d0d 100644 --- a/mmdeploy/codebase/mmdet/deploy/object_detection_model.py +++ b/mmdeploy/codebase/mmdet/deploy/object_detection_model.py @@ -293,14 +293,26 @@ def forward(self, """ 
assert mode == 'predict', 'Deploy model only allow mode=="predict".' inputs = inputs.contiguous() - outputs = self.predict(inputs) + ir_config = get_ir_config(self.deploy_cfg) + input_names = ir_config['input_names'] + if len(input_names) == 2 and 'shape' in input_names: + shape_info = [d.img_shape for d in data_samples] + shape_info = torch.tensor( + shape_info, dtype=torch.long, device=inputs.device) + else: + shape_info = None + outputs = self.predict(inputs, shape_info) batch_dets, batch_labels = outputs[:2] batch_masks = outputs[2] if len(outputs) >= 3 else None self.postprocessing_results(batch_dets, batch_labels, batch_masks, data_samples) return data_samples - def predict(self, imgs: Tensor) -> Tuple[np.ndarray, np.ndarray]: + def predict( + self, + imgs: Tensor, + shape_info: Optional[torch.Tensor] = None + ) -> Tuple[np.ndarray, np.ndarray]: """The interface for predict. Args: @@ -310,7 +322,10 @@ def predict(self, imgs: Tensor) -> Tuple[np.ndarray, np.ndarray]: tuple[np.ndarray, np.ndarray]: dets of shape [N, num_det, 5] and class labels of shape [N, num_det]. """ - outputs = self.wrapper({self.input_name: imgs}) + inputs = {self.input_name: imgs} + if shape_info is not None: + inputs['shape'] = shape_info + outputs = self.wrapper(inputs) outputs = self.wrapper.output_to_list(outputs) return outputs diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/detr_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/detr_head.py index 08121bdfbb..f386b4497d 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/detr_head.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/detr_head.py @@ -8,6 +8,8 @@ from mmdeploy.core import FUNCTION_REWRITER +@FUNCTION_REWRITER.register_rewriter( + 'mmdet.models.dense_heads.DeformableDETRHead.predict_by_feat') @FUNCTION_REWRITER.register_rewriter( 'mmdet.models.dense_heads.DETRHead.predict_by_feat') def detrhead__predict_by_feat__default(self, @@ -17,10 +19,18 @@ def detrhead__predict_by_feat__default(self, rescale: bool = True): """Rewrite `predict_by_feat` of `FoveaHead` for default backend.""" from mmdet.structures.bbox import bbox_cxcywh_to_xyxy + cls_scores = all_cls_scores_list[-1] bbox_preds = all_bbox_preds_list[-1] + if 'shape_info' in batch_img_metas[0]: + img_shape = batch_img_metas[0]['shape_info'] + else: + img_shape = batch_img_metas[0]['img_shape'] + if isinstance(img_shape, list): + img_shape = torch.tensor( + img_shape, dtype=torch.long, device=cls_scores.device) + img_shape = img_shape.unsqueeze(0) - img_shape = batch_img_metas[0]['img_shape'] max_per_img = self.test_cfg.get('max_per_img', len(cls_scores[0])) batch_size = cls_scores.size(0) # `batch_index_offset` is used for the gather of concatenated tensor @@ -49,19 +59,9 @@ def detrhead__predict_by_feat__default(self, ...].squeeze(-1) det_bboxes = bbox_cxcywh_to_xyxy(bbox_preds) - - if isinstance(img_shape, torch.Tensor): - hw = img_shape.flip(0).to(det_bboxes.device) - else: - hw = det_bboxes.new_tensor([img_shape[1], img_shape[0]]) - shape_scale = torch.cat([hw, hw]) - shape_scale = shape_scale.view(1, 1, -1) + shape_scale = img_shape.flip(1).repeat(1, 2).unsqueeze(1) det_bboxes = det_bboxes * shape_scale - # dynamically clip bboxes - x1, y1, x2, y2 = det_bboxes.split((1, 1, 1, 1), dim=-1) - from mmdeploy.codebase.mmdet.deploy import clip_bboxes - x1, y1, x2, y2 = clip_bboxes(x1, y1, x2, y2, img_shape) - det_bboxes = torch.cat([x1, y1, x2, y2], dim=-1) + det_bboxes.clamp_(min=0) det_bboxes = torch.cat((det_bboxes, scores.unsqueeze(-1)), -1) return det_bboxes, det_labels diff 
--git a/mmdeploy/codebase/mmdet/models/detectors/base_detr.py b/mmdeploy/codebase/mmdet/models/detectors/base_detr.py index 3531c9183c..7b4b9825af 100644 --- a/mmdeploy/codebase/mmdet/models/detectors/base_detr.py +++ b/mmdeploy/codebase/mmdet/models/detectors/base_detr.py @@ -2,6 +2,7 @@ import copy import torch +import torch.nn.functional as F from mmdet.models.detectors.base import ForwardResults from mmdet.structures import DetDataSample from mmdet.structures.det_data_sample import OptSampleList @@ -47,9 +48,10 @@ def _set_metainfo(data_samples, img_shape): @FUNCTION_REWRITER.register_rewriter( - 'mmdet.models.detectors.base_detr.DetectionTransformer.predict') -def detection_transformer__predict(self, + 'mmdet.models.detectors.base_detr.DetectionTransformer.forward') +def detection_transformer__forward(self, batch_inputs: torch.Tensor, + shape_info: torch.Tensor = None, data_samples: OptSampleList = None, rescale: bool = True, **kwargs) -> ForwardResults: @@ -79,11 +81,68 @@ def detection_transformer__predict(self, # get origin input shape as tensor to support onnx dynamic shape is_dynamic_flag = is_dynamic_shape(deploy_cfg) - img_shape = torch._shape_as_tensor(batch_inputs)[2:] + img_shape = torch._shape_as_tensor(batch_inputs)[2:].to( + batch_inputs.device) if not is_dynamic_flag: img_shape = [int(val) for val in img_shape] # set the metainfo data_samples = _set_metainfo(data_samples, img_shape) - + if shape_info is not None: + data_samples[0].set_field( + name='shape_info', value=shape_info, field_type='metainfo') return __predict_impl(self, batch_inputs, data_samples, rescale) + + +@FUNCTION_REWRITER.register_rewriter( + 'mmdet.models.detectors.detr.DETR.pre_transformer') +def detection_transformer__pre_transformer( + self, img_feats, batch_data_samples: OptSampleList = None): + + feat = img_feats[-1] # NOTE img_feats contains only one feature. + batch_size, feat_dim, _, _ = feat.shape + # construct binary masks which for the transformer. + assert batch_data_samples is not None + # masks = batch_data_samples[0].masks.to(torch.float32) + batch_input_shape = batch_data_samples[0].img_shape + if 'shape_info' in batch_data_samples[0]: + batch_shape_info = batch_data_samples[0].shape_info + masks_h = torch.arange( + batch_input_shape[0], + device=feat.device).reshape(1, -1, + 1).expand(batch_size, -1, + batch_input_shape[1]) + masks_w = torch.arange( + batch_input_shape[1], + device=feat.device).reshape(1, 1, + -1).expand(batch_size, + batch_input_shape[0], -1) + masks_h = masks_h >= batch_shape_info[:, 0].view(-1, 1, 1) + masks_w = masks_w >= batch_shape_info[:, 1].view(-1, 1, 1) + masks = torch.logical_or(masks_h, masks_w).to(torch.float32) + else: + masks = torch.zeros( + batch_size, + batch_input_shape[0], + batch_input_shape[1], + device=feat.device) + + # NOTE following the official DETR repo, non-zero values represent + # ignored positions, while zero values mean valid positions. 
+ + masks = F.interpolate( + masks.unsqueeze(1), size=feat.shape[-2:]).to(torch.bool).squeeze(1) + # [batch_size, embed_dim, h, w] + pos_embed = self.positional_encoding(masks) + + # use `view` instead of `flatten` for dynamically exporting to ONNX + # [bs, c, h, w] -> [bs, h*w, c] + feat = feat.view(batch_size, feat_dim, -1).permute(0, 2, 1) + pos_embed = pos_embed.view(batch_size, feat_dim, -1).permute(0, 2, 1) + # [bs, h, w] -> [bs, h*w] + masks = masks.view(batch_size, -1) + + # prepare transformer_inputs_dict + encoder_inputs_dict = dict(feat=feat, feat_mask=masks, feat_pos=pos_embed) + decoder_inputs_dict = dict(memory_mask=masks, memory_pos=pos_embed) + return encoder_inputs_dict, decoder_inputs_dict diff --git a/mmdeploy/utils/config_utils.py b/mmdeploy/utils/config_utils.py index 5565596fee..6af418421a 100644 --- a/mmdeploy/utils/config_utils.py +++ b/mmdeploy/utils/config_utils.py @@ -245,7 +245,7 @@ def is_dynamic_shape(deploy_cfg: Union[str, mmengine.Config], return False # check if 2 (height) and 3 (width) in input axes - if 2 in input_axes and 3 in input_axes: + if 2 in input_axes or 3 in input_axes: return True return False diff --git a/tests/regression/mmdet.yml b/tests/regression/mmdet.yml index 1df7404e5e..1f28530fe3 100644 --- a/tests/regression/mmdet.yml +++ b/tests/regression/mmdet.yml @@ -310,11 +310,12 @@ models: model_configs: - configs/detr/detr_r50_8xb2-150e_coco.py pipelines: - - *pipeline_ort_dynamic_fp32 - - deploy_config: configs/mmdet/detection/detection_tensorrt-fp16_dynamic-64x64-800x800.py + - deploy_config: configs/mmdet/detection/detection_detr_onnxruntime_dynamic.py convert_image: *convert_image - backend_test: *default_backend_test - sdk_config: *sdk_dynamic + backend_test: False + - deploy_config: configs/mmdet/detection/detection_detr_tensorrt_dynamic-320x320-1344x1344.py + convert_image: *convert_image + backend_test: True - name: CenterNet metafile: configs/centernet/metafile.yml @@ -416,3 +417,15 @@ models: - deploy_config: configs/mmdet/panoptic-seg/panoptic-seg_maskformer_tensorrt_static-800x1344.py convert_image: *convert_image backend_test: *default_backend_test + + - name: DINO + metafile: configs/dino/metafile.yml + model_configs: + - configs/dino/dino-4scale_r50_8xb2-12e_coco.py + pipelines: + - deploy_config: configs/mmdet/detection/detection_detr_onnxruntime_dynamic.py + convert_image: *convert_image + backend_test: False + - deploy_config: configs/mmdet/detection/detection_detr_tensorrt_dynamic-320x320-1344x1344.py + convert_image: *convert_image + backend_test: True From f97bbff275ba3ca1d213b6b38f87c50c5ebd466d Mon Sep 17 00:00:00 2001 From: RunningLeon Date: Mon, 4 Sep 2023 11:47:50 +0800 Subject: [PATCH 2/7] support dino --- .../codebase/mmdet/deploy/object_detection.py | 3 +- .../mmdet/models/detectors/base_detr.py | 133 +++++++++++++++--- tests/regression/mmdet.yml | 40 +++++- 3 files changed, 156 insertions(+), 20 deletions(-) diff --git a/mmdeploy/codebase/mmdet/deploy/object_detection.py b/mmdeploy/codebase/mmdet/deploy/object_detection.py index 1e619273e0..32384710a6 100644 --- a/mmdeploy/codebase/mmdet/deploy/object_detection.py +++ b/mmdeploy/codebase/mmdet/deploy/object_detection.py @@ -221,7 +221,8 @@ def create_input( model_type = self.model_cfg.model.type inputs = data['inputs'] if dynamic_flag and is_dynamic_batch( - self.deploy_cfg) and model_type in ['DETR']: + self.deploy_cfg) and ('DETR' in model_type + or model_type == 'DINO'): shape_info = torch._shape_as_tensor(inputs)[2:].unsqueeze( 
0).to(torch.long).to(inputs.device) inputs = (inputs, shape_info) diff --git a/mmdeploy/codebase/mmdet/models/detectors/base_detr.py b/mmdeploy/codebase/mmdet/models/detectors/base_detr.py index 7b4b9825af..b89ac8789f 100644 --- a/mmdeploy/codebase/mmdet/models/detectors/base_detr.py +++ b/mmdeploy/codebase/mmdet/models/detectors/base_detr.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import copy +from typing import Dict, Tuple import torch import torch.nn.functional as F @@ -94,29 +95,18 @@ def detection_transformer__forward(self, return __predict_impl(self, batch_inputs, data_samples, rescale) -@FUNCTION_REWRITER.register_rewriter( - 'mmdet.models.detectors.detr.DETR.pre_transformer') -def detection_transformer__pre_transformer( - self, img_feats, batch_data_samples: OptSampleList = None): - - feat = img_feats[-1] # NOTE img_feats contains only one feature. - batch_size, feat_dim, _, _ = feat.shape - # construct binary masks which for the transformer. - assert batch_data_samples is not None - # masks = batch_data_samples[0].masks.to(torch.float32) +def _generate_masks(batch_size, batch_data_samples, device): batch_input_shape = batch_data_samples[0].img_shape if 'shape_info' in batch_data_samples[0]: batch_shape_info = batch_data_samples[0].shape_info masks_h = torch.arange( batch_input_shape[0], - device=feat.device).reshape(1, -1, - 1).expand(batch_size, -1, - batch_input_shape[1]) + device=device).reshape(1, -1, 1).expand(batch_size, -1, + batch_input_shape[1]) masks_w = torch.arange( batch_input_shape[1], - device=feat.device).reshape(1, 1, - -1).expand(batch_size, - batch_input_shape[0], -1) + device=device).reshape(1, 1, -1).expand(batch_size, + batch_input_shape[0], -1) masks_h = masks_h >= batch_shape_info[:, 0].view(-1, 1, 1) masks_w = masks_w >= batch_shape_info[:, 1].view(-1, 1, 1) masks = torch.logical_or(masks_h, masks_w).to(torch.float32) @@ -125,7 +115,19 @@ def detection_transformer__pre_transformer( batch_size, batch_input_shape[0], batch_input_shape[1], - device=feat.device) + device=device) + return masks + + +@FUNCTION_REWRITER.register_rewriter( + 'mmdet.models.detectors.detr.DETR.pre_transformer') +def detr__pre_transformer(self, img_feats, batch_data_samples: OptSampleList): + + feat = img_feats[-1] # NOTE img_feats contains only one feature. + batch_size, feat_dim, _, _ = feat.shape + # construct binary masks which for the transformer. + assert batch_data_samples is not None + masks = _generate_masks(batch_size, batch_data_samples, feat.device) # NOTE following the official DETR repo, non-zero values represent # ignored positions, while zero values mean valid positions. @@ -146,3 +148,100 @@ def detection_transformer__pre_transformer( encoder_inputs_dict = dict(feat=feat, feat_mask=masks, feat_pos=pos_embed) decoder_inputs_dict = dict(memory_mask=masks, memory_pos=pos_embed) return encoder_inputs_dict, decoder_inputs_dict + + +@FUNCTION_REWRITER.register_rewriter( + 'mmdet.models.detectors.deformable_detr.DeformableDETR.pre_transformer') +def deformable_detr__pre_transformer( + self, + mlvl_feats: Tuple[torch.Tensor], + batch_data_samples: OptSampleList = None) -> Tuple[Dict]: + """Process image features before feeding them to the transformer. + + The forward procedure of the transformer is defined as: + 'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder' + More details can be found at `TransformerDetector.forward_transformer` + in `mmdet/detector/base_detr.py`. 
+ + Args: + mlvl_feats (tuple[Tensor]): Multi-level features that may have + different resolutions, output from neck. Each feature has + shape (bs, dim, h_lvl, w_lvl), where 'lvl' means 'layer'. + batch_data_samples (list[:obj:`DetDataSample`], optional): The + batch data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + Defaults to None. + + Returns: + tuple[dict]: The first dict contains the inputs of encoder and the + second dict contains the inputs of decoder. + + - encoder_inputs_dict (dict): The keyword args dictionary of + `self.forward_encoder()`, which includes 'feat', 'feat_mask', + and 'feat_pos'. + - decoder_inputs_dict (dict): The keyword args dictionary of + `self.forward_decoder()`, which includes 'memory_mask'. + """ + batch_size = mlvl_feats[0].size(0) + + # construct binary masks for the transformer. + assert batch_data_samples is not None + masks = _generate_masks(batch_size, batch_data_samples, + mlvl_feats[0].device) + # NOTE following the official DETR repo, non-zero values representing + # ignored positions, while zero values means valid positions. + + mlvl_masks = [] + mlvl_pos_embeds = [] + for feat in mlvl_feats: + mlvl_masks.append( + F.interpolate(masks[None], + size=feat.shape[-2:]).to(torch.bool).squeeze(0)) + mlvl_pos_embeds.append(self.positional_encoding(mlvl_masks[-1])) + + feat_flatten = [] + lvl_pos_embed_flatten = [] + mask_flatten = [] + spatial_shapes = [] + for lvl, (feat, mask, pos_embed) in enumerate( + zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)): + batch_size, c, h, w = feat.shape + spatial_shape = torch._shape_as_tensor(feat)[2:].to(feat.device) + # [bs, c, h_lvl, w_lvl] -> [bs, h_lvl*w_lvl, c] + feat = feat.view(batch_size, c, -1).permute(0, 2, 1) + pos_embed = pos_embed.view(batch_size, c, -1).permute(0, 2, 1) + lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1) + # [bs, h_lvl, w_lvl] -> [bs, h_lvl*w_lvl] + mask = mask.flatten(1) + feat_flatten.append(feat) + lvl_pos_embed_flatten.append(lvl_pos_embed) + mask_flatten.append(mask) + spatial_shapes.append(spatial_shape) + + # (bs, num_feat_points, dim) + feat_flatten = torch.cat(feat_flatten, 1) + lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) + # (bs, num_feat_points), where num_feat_points = sum_lvl(h_lvl*w_lvl) + mask_flatten = torch.cat(mask_flatten, 1) + + # (num_level, 2) + spatial_shapes = torch.cat(spatial_shapes).view(-1, 2) + level_start_index = torch.cat(( + spatial_shapes.new_zeros((1, )), # (num_level) + spatial_shapes.prod(1).cumsum(0)[:-1])) + valid_ratios = torch.stack( # (bs, num_level, 2) + [self.get_valid_ratio(m) for m in mlvl_masks], 1) + + encoder_inputs_dict = dict( + feat=feat_flatten, + feat_mask=mask_flatten, + feat_pos=lvl_pos_embed_flatten, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + valid_ratios=valid_ratios) + decoder_inputs_dict = dict( + memory_mask=mask_flatten, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + valid_ratios=valid_ratios) + return encoder_inputs_dict, decoder_inputs_dict diff --git a/tests/regression/mmdet.yml b/tests/regression/mmdet.yml index 1f28530fe3..c3b141ef92 100644 --- a/tests/regression/mmdet.yml +++ b/tests/regression/mmdet.yml @@ -312,7 +312,7 @@ models: pipelines: - deploy_config: configs/mmdet/detection/detection_detr_onnxruntime_dynamic.py convert_image: *convert_image - backend_test: False + backend_test: True - deploy_config: 
configs/mmdet/detection/detection_detr_tensorrt_dynamic-320x320-1344x1344.py convert_image: *convert_image backend_test: True @@ -425,7 +425,43 @@ models: pipelines: - deploy_config: configs/mmdet/detection/detection_detr_onnxruntime_dynamic.py convert_image: *convert_image - backend_test: False + backend_test: True + - deploy_config: configs/mmdet/detection/detection_detr_tensorrt_dynamic-320x320-1344x1344.py + convert_image: *convert_image + backend_test: True + + - name: ConditionalDETR + metafile: configs/conditional_detr/metafile.yml + model_configs: + - configs/conditional_detr/conditional-detr_r50_8xb2-50e_coco.py + pipelines: + - deploy_config: configs/mmdet/detection/detection_detr_onnxruntime_dynamic.py + convert_image: *convert_image + backend_test: True + - deploy_config: configs/mmdet/detection/detection_detr_tensorrt_dynamic-320x320-1344x1344.py + convert_image: *convert_image + backend_test: True + + - name: DAB-DETR + metafile: configs/dab_detr/metafile.yml + model_configs: + - configs/dab_detr/dab-detr_r50_8xb2-50e_coco.py + pipelines: + - deploy_config: configs/mmdet/detection/detection_detr_onnxruntime_dynamic.py + convert_image: *convert_image + backend_test: True + - deploy_config: configs/mmdet/detection/detection_detr_tensorrt_dynamic-320x320-1344x1344.py + convert_image: *convert_image + backend_test: True + + - name: DeformableDETR + metafile: configs/deformable_detr/metafile.yml + model_configs: + - configs/deformable_detr/deformable-detr_r50_16xb2-50e_coco.py + pipelines: + - deploy_config: configs/mmdet/detection/detection_detr_onnxruntime_dynamic.py + convert_image: *convert_image + backend_test: True - deploy_config: configs/mmdet/detection/detection_detr_tensorrt_dynamic-320x320-1344x1344.py convert_image: *convert_image backend_test: True From 5289ce4a4d03664f2f19e3f8d5bcadac70e54268 Mon Sep 17 00:00:00 2001 From: RunningLeon Date: Tue, 5 Sep 2023 14:00:20 +0800 Subject: [PATCH 3/7] remove dynamic batch --- .../detection_detr_onnxruntime_dynamic.py | 9 ----- ...detr_tensorrt_dynamic-320x320-1344x1344.py | 22 ---------- .../codebase/mmdet/deploy/object_detection.py | 12 +----- .../mmdet/deploy/object_detection_model.py | 15 +------ .../mmdet/models/dense_heads/detr_head.py | 13 +++--- .../mmdet/models/detectors/base_detr.py | 19 +++++---- tests/regression/mmdet.yml | 40 +++++-------------- 7 files changed, 31 insertions(+), 99 deletions(-) delete mode 100644 configs/mmdet/detection/detection_detr_onnxruntime_dynamic.py delete mode 100644 configs/mmdet/detection/detection_detr_tensorrt_dynamic-320x320-1344x1344.py diff --git a/configs/mmdet/detection/detection_detr_onnxruntime_dynamic.py b/configs/mmdet/detection/detection_detr_onnxruntime_dynamic.py deleted file mode 100644 index aee8990082..0000000000 --- a/configs/mmdet/detection/detection_detr_onnxruntime_dynamic.py +++ /dev/null @@ -1,9 +0,0 @@ -_base_ = ['../_base_/base_dynamic.py', '../../_base_/backends/onnxruntime.py'] -onnx_config = dict( - input_names=['input', 'shape'], - dynamic_axes={ - 'shape': { - 0: 'batch' - }, - }, -) diff --git a/configs/mmdet/detection/detection_detr_tensorrt_dynamic-320x320-1344x1344.py b/configs/mmdet/detection/detection_detr_tensorrt_dynamic-320x320-1344x1344.py deleted file mode 100644 index bf023726ab..0000000000 --- a/configs/mmdet/detection/detection_detr_tensorrt_dynamic-320x320-1344x1344.py +++ /dev/null @@ -1,22 +0,0 @@ -_base_ = ['../_base_/base_dynamic.py', '../../_base_/backends/tensorrt.py'] -onnx_config = dict( - input_names=['input', 'shape'], - 
dynamic_axes={ - 'shape': { - 0: 'batch', - }, - }, -) - -backend_config = dict( - common_config=dict(max_workspace_size=1 << 30), - model_inputs=[ - dict( - input_shapes=dict( - shape=dict( - min_shape=[1, 2], opt_shape=[1, 2], max_shape=[2, 2]), - input=dict( - min_shape=[1, 3, 320, 320], - opt_shape=[1, 3, 800, 1344], - max_shape=[2, 3, 1344, 1344]))) - ]) diff --git a/mmdeploy/codebase/mmdet/deploy/object_detection.py b/mmdeploy/codebase/mmdet/deploy/object_detection.py index 32384710a6..92152d73cf 100644 --- a/mmdeploy/codebase/mmdet/deploy/object_detection.py +++ b/mmdeploy/codebase/mmdet/deploy/object_detection.py @@ -12,7 +12,7 @@ from mmdeploy.codebase.base import CODEBASE, BaseTask, MMCodebase from mmdeploy.utils import Backend, Codebase, Task from mmdeploy.utils.config_utils import (get_backend, get_input_shape, - is_dynamic_batch, is_dynamic_shape) + is_dynamic_shape) MMDET_TASK = Registry('mmdet_tasks') @@ -218,15 +218,7 @@ def create_input( data = pseudo_collate(data) if data_preprocessor is not None: data = data_preprocessor(data, False) - model_type = self.model_cfg.model.type - inputs = data['inputs'] - if dynamic_flag and is_dynamic_batch( - self.deploy_cfg) and ('DETR' in model_type - or model_type == 'DINO'): - shape_info = torch._shape_as_tensor(inputs)[2:].unsqueeze( - 0).to(torch.long).to(inputs.device) - inputs = (inputs, shape_info) - return data, inputs + return data, data['inputs'] else: return data, BaseTask.get_tensor_from_input(data) diff --git a/mmdeploy/codebase/mmdet/deploy/object_detection_model.py b/mmdeploy/codebase/mmdet/deploy/object_detection_model.py index 9256b25d0d..b562c3f41f 100644 --- a/mmdeploy/codebase/mmdet/deploy/object_detection_model.py +++ b/mmdeploy/codebase/mmdet/deploy/object_detection_model.py @@ -293,15 +293,7 @@ def forward(self, """ assert mode == 'predict', 'Deploy model only allow mode=="predict".' inputs = inputs.contiguous() - ir_config = get_ir_config(self.deploy_cfg) - input_names = ir_config['input_names'] - if len(input_names) == 2 and 'shape' in input_names: - shape_info = [d.img_shape for d in data_samples] - shape_info = torch.tensor( - shape_info, dtype=torch.long, device=inputs.device) - else: - shape_info = None - outputs = self.predict(inputs, shape_info) + outputs = self.predict(inputs) batch_dets, batch_labels = outputs[:2] batch_masks = outputs[2] if len(outputs) >= 3 else None self.postprocessing_results(batch_dets, batch_labels, batch_masks, @@ -322,10 +314,7 @@ def predict( tuple[np.ndarray, np.ndarray]: dets of shape [N, num_det, 5] and class labels of shape [N, num_det]. 
""" - inputs = {self.input_name: imgs} - if shape_info is not None: - inputs['shape'] = shape_info - outputs = self.wrapper(inputs) + outputs = self.wrapper({self.input_name: imgs}) outputs = self.wrapper.output_to_list(outputs) return outputs diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/detr_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/detr_head.py index f386b4497d..6d161b5c25 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/detr_head.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/detr_head.py @@ -22,14 +22,11 @@ def detrhead__predict_by_feat__default(self, cls_scores = all_cls_scores_list[-1] bbox_preds = all_bbox_preds_list[-1] - if 'shape_info' in batch_img_metas[0]: - img_shape = batch_img_metas[0]['shape_info'] - else: - img_shape = batch_img_metas[0]['img_shape'] - if isinstance(img_shape, list): - img_shape = torch.tensor( - img_shape, dtype=torch.long, device=cls_scores.device) - img_shape = img_shape.unsqueeze(0) + img_shape = batch_img_metas[0]['img_shape'] + if isinstance(img_shape, list): + img_shape = torch.tensor( + img_shape, dtype=torch.long, device=cls_scores.device) + img_shape = img_shape.unsqueeze(0) max_per_img = self.test_cfg.get('max_per_img', len(cls_scores[0])) batch_size = cls_scores.size(0) diff --git a/mmdeploy/codebase/mmdet/models/detectors/base_detr.py b/mmdeploy/codebase/mmdet/models/detectors/base_detr.py index b89ac8789f..da5099c748 100644 --- a/mmdeploy/codebase/mmdet/models/detectors/base_detr.py +++ b/mmdeploy/codebase/mmdet/models/detectors/base_detr.py @@ -52,7 +52,6 @@ def _set_metainfo(data_samples, img_shape): 'mmdet.models.detectors.base_detr.DetectionTransformer.forward') def detection_transformer__forward(self, batch_inputs: torch.Tensor, - shape_info: torch.Tensor = None, data_samples: OptSampleList = None, rescale: bool = True, **kwargs) -> ForwardResults: @@ -89,9 +88,6 @@ def detection_transformer__forward(self, # set the metainfo data_samples = _set_metainfo(data_samples, img_shape) - if shape_info is not None: - data_samples[0].set_field( - name='shape_info', value=shape_info, field_type='metainfo') return __predict_impl(self, batch_inputs, data_samples, rescale) @@ -127,7 +123,12 @@ def detr__pre_transformer(self, img_feats, batch_data_samples: OptSampleList): batch_size, feat_dim, _, _ = feat.shape # construct binary masks which for the transformer. assert batch_data_samples is not None - masks = _generate_masks(batch_size, batch_data_samples, feat.device) + batch_input_shape = batch_data_samples[0].img_shape + masks = torch.zeros( + batch_size, + batch_input_shape[0], + batch_input_shape[1], + device=feat.device) # NOTE following the official DETR repo, non-zero values represent # ignored positions, while zero values mean valid positions. @@ -186,8 +187,12 @@ def deformable_detr__pre_transformer( # construct binary masks for the transformer. assert batch_data_samples is not None - masks = _generate_masks(batch_size, batch_data_samples, - mlvl_feats[0].device) + batch_input_shape = batch_data_samples[0].img_shape + masks = torch.zeros( + batch_size, + batch_input_shape[0], + batch_input_shape[1], + device=mlvl_feats[0].device) # NOTE following the official DETR repo, non-zero values representing # ignored positions, while zero values means valid positions. 
diff --git a/tests/regression/mmdet.yml b/tests/regression/mmdet.yml index c3b141ef92..f0e813ce8e 100644 --- a/tests/regression/mmdet.yml +++ b/tests/regression/mmdet.yml @@ -310,12 +310,8 @@ models: model_configs: - configs/detr/detr_r50_8xb2-150e_coco.py pipelines: - - deploy_config: configs/mmdet/detection/detection_detr_onnxruntime_dynamic.py - convert_image: *convert_image - backend_test: True - - deploy_config: configs/mmdet/detection/detection_detr_tensorrt_dynamic-320x320-1344x1344.py - convert_image: *convert_image - backend_test: True + - *pipeline_ort_dynamic_fp32 + - *pipeline_trt_dynamic_fp32 - name: CenterNet metafile: configs/centernet/metafile.yml @@ -423,45 +419,29 @@ models: model_configs: - configs/dino/dino-4scale_r50_8xb2-12e_coco.py pipelines: - - deploy_config: configs/mmdet/detection/detection_detr_onnxruntime_dynamic.py - convert_image: *convert_image - backend_test: True - - deploy_config: configs/mmdet/detection/detection_detr_tensorrt_dynamic-320x320-1344x1344.py - convert_image: *convert_image - backend_test: True + - *pipeline_ort_dynamic_fp32 + - *pipeline_trt_dynamic_fp32 - name: ConditionalDETR metafile: configs/conditional_detr/metafile.yml model_configs: - configs/conditional_detr/conditional-detr_r50_8xb2-50e_coco.py pipelines: - - deploy_config: configs/mmdet/detection/detection_detr_onnxruntime_dynamic.py - convert_image: *convert_image - backend_test: True - - deploy_config: configs/mmdet/detection/detection_detr_tensorrt_dynamic-320x320-1344x1344.py - convert_image: *convert_image - backend_test: True + - *pipeline_ort_dynamic_fp32 + - *pipeline_trt_dynamic_fp32 - name: DAB-DETR metafile: configs/dab_detr/metafile.yml model_configs: - configs/dab_detr/dab-detr_r50_8xb2-50e_coco.py pipelines: - - deploy_config: configs/mmdet/detection/detection_detr_onnxruntime_dynamic.py - convert_image: *convert_image - backend_test: True - - deploy_config: configs/mmdet/detection/detection_detr_tensorrt_dynamic-320x320-1344x1344.py - convert_image: *convert_image - backend_test: True + - *pipeline_ort_dynamic_fp32 + - *pipeline_trt_dynamic_fp32 - name: DeformableDETR metafile: configs/deformable_detr/metafile.yml model_configs: - configs/deformable_detr/deformable-detr_r50_16xb2-50e_coco.py pipelines: - - deploy_config: configs/mmdet/detection/detection_detr_onnxruntime_dynamic.py - convert_image: *convert_image - backend_test: True - - deploy_config: configs/mmdet/detection/detection_detr_tensorrt_dynamic-320x320-1344x1344.py - convert_image: *convert_image - backend_test: True + - *pipeline_ort_dynamic_fp32 + - *pipeline_trt_dynamic_fp32 From b0c7745faa9660d7d7bd4699f170082a440d9b45 Mon Sep 17 00:00:00 2001 From: RunningLeon Date: Tue, 5 Sep 2023 17:19:08 +0800 Subject: [PATCH 4/7] update doc --- docs/en/04-supported-codebases/mmdet.md | 59 ++++++++++++---------- docs/zh_cn/04-supported-codebases/mmdet.md | 59 ++++++++++++---------- 2 files changed, 64 insertions(+), 54 deletions(-) diff --git a/docs/en/04-supported-codebases/mmdet.md b/docs/en/04-supported-codebases/mmdet.md index 84e1fe5922..dba7b25d27 100644 --- a/docs/en/04-supported-codebases/mmdet.md +++ b/docs/en/04-supported-codebases/mmdet.md @@ -190,35 +190,40 @@ Besides python API, mmdeploy SDK also provides other FFI (Foreign Function Inter ## Supported models -| Model | Task | OnnxRuntime | TensorRT | ncnn | PPLNN | OpenVINO | -| :------------------------------------------------------------------------------------------------------: | :-------------------: | :---------: | :------: | :--: | :---: | 
:------: |
-| [ATSS](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/atss) | Object Detection | Y | Y | N | N | Y |
-| [FCOS](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/fcos) | Object Detection | Y | Y | Y | N | Y |
-| [FoveaBox](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/foveabox) | Object Detection | Y | N | N | N | Y |
-| [FSAF](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/fsaf) | Object Detection | Y | Y | Y | Y | Y |
-| [RetinaNet](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/retinanet) | Object Detection | Y | Y | Y | Y | Y |
-| [SSD](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/ssd) | Object Detection | Y | Y | Y | N | Y |
-| [VFNet](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/vfnet) | Object Detection | N | N | N | N | Y |
-| [YOLOv3](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/yolo) | Object Detection | Y | Y | Y | N | Y |
-| [YOLOX](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/yolox) | Object Detection | Y | Y | Y | N | Y |
-| [Cascade R-CNN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/cascade_rcnn) | Object Detection | Y | Y | N | Y | Y |
-| [Faster R-CNN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/faster_rcnn) | Object Detection | Y | Y | Y | Y | Y |
-| [Faster R-CNN + DCN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/faster_rcnn) | Object Detection | Y | Y | Y | Y | Y |
-| [GFL](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/gfl) | Object Detection | Y | Y | N | ? | Y |
-| [RepPoints](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/reppoints) | Object Detection | N | Y | N | ? | Y |
-| [DETR](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/detr) | Object Detection | Y | Y | N | ? | Y |
-| [CenterNet](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/centernet) | Object Detection | Y | Y | N | ? | Y |
-| [RTMDet](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/rtmdet) | Object Detection | Y | Y | N | ? 
| Y |
-| [Cascade Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/cascade_rcnn) | Instance Segmentation | Y | Y | N | N | Y |
-| [Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/mask_rcnn) | Instance Segmentation | Y | Y | N | N | Y |
-| [Swin Transformer](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/swin) | Instance Segmentation | Y | Y | N | N | Y |
-| [SOLO](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/solo) | Instance Segmentation | Y | N | N | N | Y |
-| [SOLOv2](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/solov2) | Instance Segmentation | Y | N | N | N | Y |
-| [Panoptic FPN](https://github.com/open-mmlab/mmdetection/tree/main/configs/panoptic_fpn) | Panoptic Segmentation | Y | Y | N | N | N |
-| [MaskFormer](https://github.com/open-mmlab/mmdetection/tree/main/configs/maskformer) | Panoptic Segmentation | Y | Y | N | N | N |
-| [Mask2Former](https://github.com/open-mmlab/mmdetection/tree/main/configs/mask2former)[\*](#mask2former) | Panoptic Segmentation | Y | Y | N | N | N |
+| Model | Task | OnnxRuntime | TensorRT | ncnn | PPLNN | OpenVINO |
+| :-----------------------------------------------------------------------------------------------------------------: | :-------------------: | :---------: | :------: | :--: | :---: | :------: |
+| [ATSS](https://github.com/open-mmlab/mmdetection/tree/main/configs/atss) | Object Detection | Y | Y | N | N | Y |
+| [FCOS](https://github.com/open-mmlab/mmdetection/tree/main/configs/fcos) | Object Detection | Y | Y | Y | N | Y |
+| [FoveaBox](https://github.com/open-mmlab/mmdetection/tree/main/configs/foveabox) | Object Detection | Y | N | N | N | Y |
+| [FSAF](https://github.com/open-mmlab/mmdetection/tree/main/configs/fsaf) | Object Detection | Y | Y | Y | Y | Y |
+| [RetinaNet](https://github.com/open-mmlab/mmdetection/tree/main/configs/retinanet) | Object Detection | Y | Y | Y | Y | Y |
+| [SSD](https://github.com/open-mmlab/mmdetection/tree/main/configs/ssd) | Object Detection | Y | Y | Y | N | Y |
+| [VFNet](https://github.com/open-mmlab/mmdetection/tree/main/configs/vfnet) | Object Detection | N | N | N | N | Y |
+| [YOLOv3](https://github.com/open-mmlab/mmdetection/tree/main/configs/yolo) | Object Detection | Y | Y | Y | N | Y |
+| [YOLOX](https://github.com/open-mmlab/mmdetection/tree/main/configs/yolox) | Object Detection | Y | Y | Y | N | Y |
+| [Cascade R-CNN](https://github.com/open-mmlab/mmdetection/tree/main/configs/cascade_rcnn) | Object Detection | Y | Y | N | Y | Y |
+| [Faster R-CNN](https://github.com/open-mmlab/mmdetection/tree/main/configs/faster_rcnn) | Object Detection | Y | Y | Y | Y | Y |
+| [Faster R-CNN + DCN](https://github.com/open-mmlab/mmdetection/tree/main/configs/faster_rcnn) | Object Detection | Y | Y | Y | Y | Y |
+| [GFL](https://github.com/open-mmlab/mmdetection/tree/main/configs/gfl) | Object Detection | Y | Y | N | ? | Y |
+| [RepPoints](https://github.com/open-mmlab/mmdetection/tree/main/configs/reppoints) | Object Detection | N | Y | N | ? 
| Y |
+| [DETR](https://github.com/open-mmlab/mmdetection/tree/main/configs/detr)[\*](#nobatchinfer) | Object Detection | Y | Y | N | ? | Y |
+| [Deformable DETR](https://github.com/open-mmlab/mmdetection/tree/main/configs/deformable_detr)[\*](#nobatchinfer) | Object Detection | Y | Y | N | ? | Y |
+| [Conditional DETR](https://github.com/open-mmlab/mmdetection/tree/main/configs/conditional_detr)[\*](#nobatchinfer) | Object Detection | Y | Y | N | ? | Y |
+| [DAB-DETR](https://github.com/open-mmlab/mmdetection/tree/main/configs/dab_detr)[\*](#nobatchinfer) | Object Detection | Y | Y | N | ? | Y |
+| [DINO](https://github.com/open-mmlab/mmdetection/tree/main/configs/dino)[\*](#nobatchinfer) | Object Detection | Y | Y | N | ? | Y |
+| [CenterNet](https://github.com/open-mmlab/mmdetection/tree/main/configs/centernet) | Object Detection | Y | Y | N | ? | Y |
+| [RTMDet](https://github.com/open-mmlab/mmdetection/tree/main/configs/rtmdet) | Object Detection | Y | Y | N | ? | Y |
+| [Cascade Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/main/configs/cascade_rcnn) | Instance Segmentation | Y | Y | N | N | Y |
+| [Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/main/configs/mask_rcnn) | Instance Segmentation | Y | Y | N | N | Y |
+| [Swin Transformer](https://github.com/open-mmlab/mmdetection/tree/main/configs/swin) | Instance Segmentation | Y | Y | N | N | Y |
+| [SOLO](https://github.com/open-mmlab/mmdetection/tree/main/configs/solo) | Instance Segmentation | Y | N | N | N | Y |
+| [SOLOv2](https://github.com/open-mmlab/mmdetection/tree/main/configs/solov2) | Instance Segmentation | Y | N | N | N | Y |
+| [Panoptic FPN](https://github.com/open-mmlab/mmdetection/tree/main/configs/panoptic_fpn) | Panoptic Segmentation | Y | Y | N | N | N |
+| [MaskFormer](https://github.com/open-mmlab/mmdetection/tree/main/configs/maskformer) | Panoptic Segmentation | Y | Y | N | N | N |
+| [Mask2Former](https://github.com/open-mmlab/mmdetection/tree/main/configs/mask2former)[\*](#mask2former) | Panoptic Segmentation | Y | Y | N | N | N |
 
 ## Reminder
 
 - For transformer based models, strongly suggest use `TensorRT>=8.4`.
 - Mask2Former should use `TensorRT>=8.6.1` for dynamic shape inference.
+- DETR-like models do not support multi-batch inference.
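The reminder added above is easiest to see in code: batching images of different sizes means padding them onto a shared canvas, and the exported DETR-family graph no longer receives the masks that mark where that padding is. A minimal sketch of the padding step (the helper name and image sizes are illustrative, not taken from this patch):

```python
import torch

def pad_to_batch(images):
    """Pad a list of [3, H_i, W_i] tensors onto one canvas, as a
    detection pipeline does before stacking images into a batch."""
    max_h = max(img.shape[1] for img in images)
    max_w = max(img.shape[2] for img in images)
    batch = images[0].new_zeros(len(images), 3, max_h, max_w)
    for i, img in enumerate(images):
        batch[i, :, :img.shape[1], :img.shape[2]] = img
    return batch

batch = pad_to_batch([torch.rand(3, 480, 640), torch.rand(3, 320, 512)])
print(batch.shape)  # torch.Size([2, 3, 480, 640])
# The zero padding in batch[1] is indistinguishable from image content
# to a model exported without padding masks, hence one image at a time.
```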
diff --git a/docs/zh_cn/04-supported-codebases/mmdet.md b/docs/zh_cn/04-supported-codebases/mmdet.md
index 17c501630f..37bfe072a1 100644
--- a/docs/zh_cn/04-supported-codebases/mmdet.md
+++ b/docs/zh_cn/04-supported-codebases/mmdet.md
@@ -192,35 +192,40 @@ cv2.imwrite('output_detection.png', img)
 
 ## 模型支持列表
 
-| Model | Task | OnnxRuntime | TensorRT | ncnn | PPLNN | OpenVINO |
-| :------------------------------------------------------------------------------------------------------: | :-------------------: | :---------: | :------: | :--: | :---: | :------: |
-| [ATSS](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/atss) | Object Detection | Y | Y | N | N | Y |
-| [FCOS](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/fcos) | Object Detection | Y | Y | Y | N | Y |
-| [FoveaBox](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/foveabox) | Object Detection | Y | N | N | N | Y |
-| [FSAF](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/fsaf) | Object Detection | Y | Y | Y | Y | Y |
-| [RetinaNet](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/retinanet) | Object Detection | Y | Y | Y | Y | Y |
-| [SSD](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/ssd) | Object Detection | Y | Y | Y | N | Y |
-| [VFNet](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/vfnet) | Object Detection | N | N | N | N | Y |
-| [YOLOv3](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/yolo) | Object Detection | Y | Y | Y | N | Y |
-| [YOLOX](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/yolox) | Object Detection | Y | Y | Y | N | Y |
-| [Cascade R-CNN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/cascade_rcnn) | Object Detection | Y | Y | N | Y | Y |
-| [Faster R-CNN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/faster_rcnn) | Object Detection | Y | Y | Y | Y | Y |
-| [Faster R-CNN + DCN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/faster_rcnn) | Object Detection | Y | Y | Y | Y | Y |
-| [GFL](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/gfl) | Object Detection | Y | Y | N | ? | Y |
-| [RepPoints](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/reppoints) | Object Detection | N | Y | N | ? | Y |
-| [DETR](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/detr) | Object Detection | Y | Y | N | ? | Y |
-| [CenterNet](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/centernet) | Object Detection | Y | Y | N | ? | Y |
-| [RTMDet](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/rtmdet) | Object Detection | Y | Y | N | ? 
| Y |
-| [Cascade Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/cascade_rcnn) | Instance Segmentation | Y | Y | N | N | Y |
-| [Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/mask_rcnn) | Instance Segmentation | Y | Y | N | N | Y |
-| [Swin Transformer](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/swin) | Instance Segmentation | Y | Y | N | N | Y |
-| [SOLO](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/solo) | Instance Segmentation | Y | N | N | N | Y |
-| [SOLOv2](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/solov2) | Instance Segmentation | Y | N | N | N | Y |
-| [Panoptic FPN](https://github.com/open-mmlab/mmdetection/tree/main/configs/panoptic_fpn) | Panoptic Segmentation | Y | Y | N | N | N |
-| [MaskFormer](https://github.com/open-mmlab/mmdetection/tree/main/configs/maskformer) | Panoptic Segmentation | Y | Y | N | N | N |
-| [Mask2Former](https://github.com/open-mmlab/mmdetection/tree/main/configs/mask2former)[\*](#mask2former) | Panoptic Segmentation | Y | Y | N | N | N |
+| Model | Task | OnnxRuntime | TensorRT | ncnn | PPLNN | OpenVINO |
+| :-----------------------------------------------------------------------------------------------------------------: | :-------------------: | :---------: | :------: | :--: | :---: | :------: |
+| [ATSS](https://github.com/open-mmlab/mmdetection/tree/main/configs/atss) | Object Detection | Y | Y | N | N | Y |
+| [FCOS](https://github.com/open-mmlab/mmdetection/tree/main/configs/fcos) | Object Detection | Y | Y | Y | N | Y |
+| [FoveaBox](https://github.com/open-mmlab/mmdetection/tree/main/configs/foveabox) | Object Detection | Y | N | N | N | Y |
+| [FSAF](https://github.com/open-mmlab/mmdetection/tree/main/configs/fsaf) | Object Detection | Y | Y | Y | Y | Y |
+| [RetinaNet](https://github.com/open-mmlab/mmdetection/tree/main/configs/retinanet) | Object Detection | Y | Y | Y | Y | Y |
+| [SSD](https://github.com/open-mmlab/mmdetection/tree/main/configs/ssd) | Object Detection | Y | Y | Y | N | Y |
+| [VFNet](https://github.com/open-mmlab/mmdetection/tree/main/configs/vfnet) | Object Detection | N | N | N | N | Y |
+| [YOLOv3](https://github.com/open-mmlab/mmdetection/tree/main/configs/yolo) | Object Detection | Y | Y | Y | N | Y |
+| [YOLOX](https://github.com/open-mmlab/mmdetection/tree/main/configs/yolox) | Object Detection | Y | Y | Y | N | Y |
+| [Cascade R-CNN](https://github.com/open-mmlab/mmdetection/tree/main/configs/cascade_rcnn) | Object Detection | Y | Y | N | Y | Y |
+| [Faster R-CNN](https://github.com/open-mmlab/mmdetection/tree/main/configs/faster_rcnn) | Object Detection | Y | Y | Y | Y | Y |
+| [Faster R-CNN + DCN](https://github.com/open-mmlab/mmdetection/tree/main/configs/faster_rcnn) | Object Detection | Y | Y | Y | Y | Y |
+| [GFL](https://github.com/open-mmlab/mmdetection/tree/main/configs/gfl) | Object Detection | Y | Y | N | ? | Y |
+| [RepPoints](https://github.com/open-mmlab/mmdetection/tree/main/configs/reppoints) | Object Detection | N | Y | N | ? 
| Y |
+| [DETR](https://github.com/open-mmlab/mmdetection/tree/main/configs/detr)[\*](#nobatchinfer) | Object Detection | Y | Y | N | ? | Y |
+| [Deformable DETR](https://github.com/open-mmlab/mmdetection/tree/main/configs/deformable_detr)[\*](#nobatchinfer) | Object Detection | Y | Y | N | ? | Y |
+| [Conditional DETR](https://github.com/open-mmlab/mmdetection/tree/main/configs/conditional_detr)[\*](#nobatchinfer) | Object Detection | Y | Y | N | ? | Y |
+| [DAB-DETR](https://github.com/open-mmlab/mmdetection/tree/main/configs/dab_detr)[\*](#nobatchinfer) | Object Detection | Y | Y | N | ? | Y |
+| [DINO](https://github.com/open-mmlab/mmdetection/tree/main/configs/dino)[\*](#nobatchinfer) | Object Detection | Y | Y | N | ? | Y |
+| [CenterNet](https://github.com/open-mmlab/mmdetection/tree/main/configs/centernet) | Object Detection | Y | Y | N | ? | Y |
+| [RTMDet](https://github.com/open-mmlab/mmdetection/tree/main/configs/rtmdet) | Object Detection | Y | Y | N | ? | Y |
+| [Cascade Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/main/configs/cascade_rcnn) | Instance Segmentation | Y | Y | N | N | Y |
+| [Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/main/configs/mask_rcnn) | Instance Segmentation | Y | Y | N | N | Y |
+| [Swin Transformer](https://github.com/open-mmlab/mmdetection/tree/main/configs/swin) | Instance Segmentation | Y | Y | N | N | Y |
+| [SOLO](https://github.com/open-mmlab/mmdetection/tree/main/configs/solo) | Instance Segmentation | Y | N | N | N | Y |
+| [SOLOv2](https://github.com/open-mmlab/mmdetection/tree/main/configs/solov2) | Instance Segmentation | Y | N | N | N | Y |
+| [Panoptic FPN](https://github.com/open-mmlab/mmdetection/tree/main/configs/panoptic_fpn) | Panoptic Segmentation | Y | Y | N | N | N |
+| [MaskFormer](https://github.com/open-mmlab/mmdetection/tree/main/configs/maskformer) | Panoptic Segmentation | Y | Y | N | N | N |
+| [Mask2Former](https://github.com/open-mmlab/mmdetection/tree/main/configs/mask2former)[\*](#mask2former) | Panoptic Segmentation | Y | Y | N | N | N |
 
 ## 注意事项
 
 - 强烈建议使用`TensorRT>=8.4`来转换基于 `transformer` 的模型.
 - Mask2Former 请使用 `TensorRT>=8.6.1` 以保证动态尺寸正常推理.
+- DETR 系列模型不支持多批次推理.
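Before patch 5 removes them, it is worth restating the per-image padding masks that patches 1-2 built, since they are exactly the information given up below. A self-contained sketch of the same arange/compare construction as `_generate_masks` (the function name and sizes here are illustrative, not the patch's own):

```python
import torch

def generate_padding_masks(batch_shape, valid_hw):
    """Return [B, H, W] masks with 1 at padded pixels.

    batch_shape: (H, W) of the padded canvas.
    valid_hw: LongTensor [B, 2] with each image's valid (h, w).
    """
    H, W = batch_shape
    B = valid_hw.shape[0]
    rows = torch.arange(H).view(1, -1, 1).expand(B, -1, W)
    cols = torch.arange(W).view(1, 1, -1).expand(B, H, -1)
    pad_rows = rows >= valid_hw[:, 0].view(-1, 1, 1)  # below the valid height
    pad_cols = cols >= valid_hw[:, 1].view(-1, 1, 1)  # right of the valid width
    return torch.logical_or(pad_rows, pad_cols).to(torch.float32)

masks = generate_padding_masks((480, 640),
                               torch.tensor([[480, 640], [320, 512]]))
print(masks[0].sum().item())  # 0.0: the first image fills the canvas
print(masks[1].sum().item())  # 143360.0: padded area of the second image
```

Dropping these masks simplifies the exported graph for every backend, at the cost of the multi-batch support the earlier patches aimed for, which is what the documentation note in patch 4 records.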
From 8afc9607e456d8eca7971e9dbd2cf24df3c37b10 Mon Sep 17 00:00:00 2001 From: RunningLeon Date: Wed, 6 Sep 2023 11:38:00 +0800 Subject: [PATCH 5/7] disable exporting masks for image paddings in multi-batch inference --- .../mmdet/models/detectors/__init__.py | 7 +- .../mmdet/models/detectors/base_detr.py | 163 ------------------ .../mmdet/models/detectors/deformable_detr.py | 152 ++++++++++++++++ .../codebase/mmdet/models/detectors/detr.py | 45 +++++ .../codebase/mmdet/models/layers/__init__.py | 1 + .../models/layers/positional_encoding.py | 57 ++++++ 6 files changed, 259 insertions(+), 166 deletions(-) create mode 100644 mmdeploy/codebase/mmdet/models/detectors/deformable_detr.py create mode 100644 mmdeploy/codebase/mmdet/models/detectors/detr.py create mode 100644 mmdeploy/codebase/mmdet/models/layers/positional_encoding.py diff --git a/mmdeploy/codebase/mmdet/models/detectors/__init__.py b/mmdeploy/codebase/mmdet/models/detectors/__init__.py index 460694aa72..7c8cbb7bec 100644 --- a/mmdeploy/codebase/mmdet/models/detectors/__init__.py +++ b/mmdeploy/codebase/mmdet/models/detectors/__init__.py @@ -1,8 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. -from . import (base_detr, maskformer, panoptic_two_stage_segmentor, - single_stage, single_stage_instance_seg, two_stage) +from . import (base_detr, deformable_detr, detr, maskformer, + panoptic_two_stage_segmentor, single_stage, + single_stage_instance_seg, two_stage) __all__ = [ 'base_detr', 'single_stage', 'single_stage_instance_seg', 'two_stage', - 'panoptic_two_stage_segmentor', 'maskformer' + 'panoptic_two_stage_segmentor', 'maskformer', 'detr', 'deformable_detr' ] diff --git a/mmdeploy/codebase/mmdet/models/detectors/base_detr.py b/mmdeploy/codebase/mmdet/models/detectors/base_detr.py index da5099c748..42c0cf45f6 100644 --- a/mmdeploy/codebase/mmdet/models/detectors/base_detr.py +++ b/mmdeploy/codebase/mmdet/models/detectors/base_detr.py @@ -1,9 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import copy -from typing import Dict, Tuple import torch -import torch.nn.functional as F from mmdet.models.detectors.base import ForwardResults from mmdet.structures import DetDataSample from mmdet.structures.det_data_sample import OptSampleList @@ -89,164 +87,3 @@ def detection_transformer__forward(self, # set the metainfo data_samples = _set_metainfo(data_samples, img_shape) return __predict_impl(self, batch_inputs, data_samples, rescale) - - -def _generate_masks(batch_size, batch_data_samples, device): - batch_input_shape = batch_data_samples[0].img_shape - if 'shape_info' in batch_data_samples[0]: - batch_shape_info = batch_data_samples[0].shape_info - masks_h = torch.arange( - batch_input_shape[0], - device=device).reshape(1, -1, 1).expand(batch_size, -1, - batch_input_shape[1]) - masks_w = torch.arange( - batch_input_shape[1], - device=device).reshape(1, 1, -1).expand(batch_size, - batch_input_shape[0], -1) - masks_h = masks_h >= batch_shape_info[:, 0].view(-1, 1, 1) - masks_w = masks_w >= batch_shape_info[:, 1].view(-1, 1, 1) - masks = torch.logical_or(masks_h, masks_w).to(torch.float32) - else: - masks = torch.zeros( - batch_size, - batch_input_shape[0], - batch_input_shape[1], - device=device) - return masks - - -@FUNCTION_REWRITER.register_rewriter( - 'mmdet.models.detectors.detr.DETR.pre_transformer') -def detr__pre_transformer(self, img_feats, batch_data_samples: OptSampleList): - - feat = img_feats[-1] # NOTE img_feats contains only one feature. 
- batch_size, feat_dim, _, _ = feat.shape - # construct binary masks which for the transformer. - assert batch_data_samples is not None - batch_input_shape = batch_data_samples[0].img_shape - masks = torch.zeros( - batch_size, - batch_input_shape[0], - batch_input_shape[1], - device=feat.device) - - # NOTE following the official DETR repo, non-zero values represent - # ignored positions, while zero values mean valid positions. - - masks = F.interpolate( - masks.unsqueeze(1), size=feat.shape[-2:]).to(torch.bool).squeeze(1) - # [batch_size, embed_dim, h, w] - pos_embed = self.positional_encoding(masks) - - # use `view` instead of `flatten` for dynamically exporting to ONNX - # [bs, c, h, w] -> [bs, h*w, c] - feat = feat.view(batch_size, feat_dim, -1).permute(0, 2, 1) - pos_embed = pos_embed.view(batch_size, feat_dim, -1).permute(0, 2, 1) - # [bs, h, w] -> [bs, h*w] - masks = masks.view(batch_size, -1) - - # prepare transformer_inputs_dict - encoder_inputs_dict = dict(feat=feat, feat_mask=masks, feat_pos=pos_embed) - decoder_inputs_dict = dict(memory_mask=masks, memory_pos=pos_embed) - return encoder_inputs_dict, decoder_inputs_dict - - -@FUNCTION_REWRITER.register_rewriter( - 'mmdet.models.detectors.deformable_detr.DeformableDETR.pre_transformer') -def deformable_detr__pre_transformer( - self, - mlvl_feats: Tuple[torch.Tensor], - batch_data_samples: OptSampleList = None) -> Tuple[Dict]: - """Process image features before feeding them to the transformer. - - The forward procedure of the transformer is defined as: - 'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder' - More details can be found at `TransformerDetector.forward_transformer` - in `mmdet/detector/base_detr.py`. - - Args: - mlvl_feats (tuple[Tensor]): Multi-level features that may have - different resolutions, output from neck. Each feature has - shape (bs, dim, h_lvl, w_lvl), where 'lvl' means 'layer'. - batch_data_samples (list[:obj:`DetDataSample`], optional): The - batch data samples. It usually includes information such - as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. - Defaults to None. - - Returns: - tuple[dict]: The first dict contains the inputs of encoder and the - second dict contains the inputs of decoder. - - - encoder_inputs_dict (dict): The keyword args dictionary of - `self.forward_encoder()`, which includes 'feat', 'feat_mask', - and 'feat_pos'. - - decoder_inputs_dict (dict): The keyword args dictionary of - `self.forward_decoder()`, which includes 'memory_mask'. - """ - batch_size = mlvl_feats[0].size(0) - - # construct binary masks for the transformer. - assert batch_data_samples is not None - batch_input_shape = batch_data_samples[0].img_shape - masks = torch.zeros( - batch_size, - batch_input_shape[0], - batch_input_shape[1], - device=mlvl_feats[0].device) - # NOTE following the official DETR repo, non-zero values representing - # ignored positions, while zero values means valid positions. 
- - mlvl_masks = [] - mlvl_pos_embeds = [] - for feat in mlvl_feats: - mlvl_masks.append( - F.interpolate(masks[None], - size=feat.shape[-2:]).to(torch.bool).squeeze(0)) - mlvl_pos_embeds.append(self.positional_encoding(mlvl_masks[-1])) - - feat_flatten = [] - lvl_pos_embed_flatten = [] - mask_flatten = [] - spatial_shapes = [] - for lvl, (feat, mask, pos_embed) in enumerate( - zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)): - batch_size, c, h, w = feat.shape - spatial_shape = torch._shape_as_tensor(feat)[2:].to(feat.device) - # [bs, c, h_lvl, w_lvl] -> [bs, h_lvl*w_lvl, c] - feat = feat.view(batch_size, c, -1).permute(0, 2, 1) - pos_embed = pos_embed.view(batch_size, c, -1).permute(0, 2, 1) - lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1) - # [bs, h_lvl, w_lvl] -> [bs, h_lvl*w_lvl] - mask = mask.flatten(1) - feat_flatten.append(feat) - lvl_pos_embed_flatten.append(lvl_pos_embed) - mask_flatten.append(mask) - spatial_shapes.append(spatial_shape) - - # (bs, num_feat_points, dim) - feat_flatten = torch.cat(feat_flatten, 1) - lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) - # (bs, num_feat_points), where num_feat_points = sum_lvl(h_lvl*w_lvl) - mask_flatten = torch.cat(mask_flatten, 1) - - # (num_level, 2) - spatial_shapes = torch.cat(spatial_shapes).view(-1, 2) - level_start_index = torch.cat(( - spatial_shapes.new_zeros((1, )), # (num_level) - spatial_shapes.prod(1).cumsum(0)[:-1])) - valid_ratios = torch.stack( # (bs, num_level, 2) - [self.get_valid_ratio(m) for m in mlvl_masks], 1) - - encoder_inputs_dict = dict( - feat=feat_flatten, - feat_mask=mask_flatten, - feat_pos=lvl_pos_embed_flatten, - spatial_shapes=spatial_shapes, - level_start_index=level_start_index, - valid_ratios=valid_ratios) - decoder_inputs_dict = dict( - memory_mask=mask_flatten, - spatial_shapes=spatial_shapes, - level_start_index=level_start_index, - valid_ratios=valid_ratios) - return encoder_inputs_dict, decoder_inputs_dict diff --git a/mmdeploy/codebase/mmdet/models/detectors/deformable_detr.py b/mmdeploy/codebase/mmdet/models/detectors/deformable_detr.py new file mode 100644 index 0000000000..56141279ca --- /dev/null +++ b/mmdeploy/codebase/mmdet/models/detectors/deformable_detr.py @@ -0,0 +1,152 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, Tuple + +import torch +from mmdet.structures.det_data_sample import OptSampleList +from torch import Tensor + +from mmdeploy.core import FUNCTION_REWRITER + + +@FUNCTION_REWRITER.register_rewriter( + 'mmdet.models.detectors.deformable_detr.DeformableDETR.pre_transformer') +def deformable_detr__pre_transformer( + self, + mlvl_feats: Tuple[Tensor], + batch_data_samples: OptSampleList = None) -> Tuple[Dict]: + """Rewrite `pre_transformer` for default backend. + + Support exporting without masks for padding info. + + Args: + mlvl_feats (tuple[Tensor]): Multi-level features that may have + different resolutions, output from neck. Each feature has + shape (bs, dim, h_lvl, w_lvl), where 'lvl' means 'layer'. + batch_data_samples (list[:obj:`DetDataSample`], optional): The + batch data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + Defaults to None. + + Returns: + tuple[dict]: The first dict contains the inputs of encoder and the + second dict contains the inputs of decoder. + """ + # construct binary masks for the transformer. 
+    assert batch_data_samples is not None
+    batch_size = mlvl_feats[0].shape[0]
+    device = mlvl_feats[0].device
+    mlvl_masks = []
+    mlvl_pos_embeds = []
+    for feat in mlvl_feats:
+        mlvl_masks.append(None)
+        shape_info = dict(
+            B=batch_size, H=feat.shape[2], W=feat.shape[3], device=device)
+        mlvl_pos_embeds.append(
+            self.positional_encoding(mask=None, **shape_info))
+
+    feat_flatten = []
+    lvl_pos_embed_flatten = []
+    spatial_shapes = []
+    for lvl, (feat, mask, pos_embed) in enumerate(
+            zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)):
+        batch_size, c, h, w = feat.shape
+        spatial_shape = torch._shape_as_tensor(feat)[2:].to(feat.device)
+        # [bs, c, h_lvl, w_lvl] -> [bs, h_lvl*w_lvl, c]
+        feat = feat.view(batch_size, c, -1).permute(0, 2, 1)
+        pos_embed = pos_embed.view(batch_size, c, -1).permute(0, 2, 1)
+        lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
+        # mask flatten [bs, h_lvl, w_lvl] -> [bs, h_lvl*w_lvl] is skipped
+        feat_flatten.append(feat)
+        lvl_pos_embed_flatten.append(lvl_pos_embed)
+        spatial_shapes.append(spatial_shape)
+
+    # (bs, num_feat_points, dim)
+    feat_flatten = torch.cat(feat_flatten, 1)
+    lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
+    # (bs, num_feat_points), where num_feat_points = sum_lvl(h_lvl*w_lvl)
+    mask_flatten = None
+
+    # (num_level, 2)
+    spatial_shapes = torch.cat(spatial_shapes).view(-1, 2)
+    level_start_index = torch.cat((
+        spatial_shapes.new_zeros((1, )),  # (num_level)
+        spatial_shapes.prod(1).cumsum(0)[:-1]))
+    valid_ratios = torch.ones(
+        batch_size, len(mlvl_feats), 2, device=device)  # (bs, num_level, 2)
+
+    encoder_inputs_dict = dict(
+        feat=feat_flatten,
+        feat_mask=mask_flatten,
+        feat_pos=lvl_pos_embed_flatten,
+        spatial_shapes=spatial_shapes,
+        level_start_index=level_start_index,
+        valid_ratios=valid_ratios)
+    decoder_inputs_dict = dict(
+        memory_mask=mask_flatten,
+        spatial_shapes=spatial_shapes,
+        level_start_index=level_start_index,
+        valid_ratios=valid_ratios)
+    return encoder_inputs_dict, decoder_inputs_dict
+
+
+@FUNCTION_REWRITER.register_rewriter(
+    'mmdet.models.detectors.deformable_detr.'
+    'DeformableDETR.gen_encoder_output_proposals')
+def deformable_detr__gen_encoder_output_proposals(
+        self, memory: Tensor, memory_mask: Tensor,
+        spatial_shapes: Tensor) -> Tuple[Tensor, Tensor]:
+    """Rewrite `gen_encoder_output_proposals` for default backend.
+
+    Support exporting with `memory_mask=None`.
+
+    Args:
+        memory (Tensor): The output embeddings of the Transformer encoder,
+            has shape (bs, num_feat_points, dim).
+        memory_mask (Tensor): ByteTensor, the padding mask of the memory,
+            has shape (bs, num_feat_points).
+        spatial_shapes (Tensor): Spatial shapes of features in all levels,
+            has shape (num_levels, 2), last dimension represents (h, w).
+
+    Returns:
+        tuple: A tuple of transformed memory and proposals.
+
+        - output_memory (Tensor): The transformed memory for obtaining
+          top-k proposals, has shape (bs, num_feat_points, dim).
+        - output_proposals (Tensor): The inverse-normalized proposal, has
+          shape (batch_size, num_keys, 4) with the last dimension arranged
+          as (cx, cy, w, h).
+    """
+    assert memory_mask is None, 'only support `memory_mask=None`'
+    bs = memory.size(0)
+    proposals = []
+    for lvl, HW in enumerate(spatial_shapes):
+        H, W = HW
+        grid_y, grid_x = torch.meshgrid(
+            torch.linspace(
+                0, H - 1, H, dtype=torch.float32, device=memory.device),
+            torch.linspace(
+                0, W - 1, W, dtype=torch.float32, device=memory.device))
+        grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)
+        scale = HW.unsqueeze(0).flip(dims=[0, 1]).view(bs, 1, 1, 2)
+        grid = (grid.unsqueeze(0).expand(bs, -1, -1, -1) + 0.5) / scale
+        wh = torch.ones_like(grid) * 0.05 * (2.0**lvl)
+        proposal = torch.cat((grid, wh), -1).view(bs, -1, 4)
+        proposals.append(proposal)
+    output_proposals = torch.cat(proposals, 1)
+    # do not use `all` to make it exportable to onnx
+    output_proposals_valid = ((output_proposals > 0.01) &
+                              (output_proposals < 0.99)).sum(
+                                  -1,
+                                  keepdim=True) == output_proposals.shape[-1]
+    # inverse_sigmoid
+    output_proposals = torch.log(output_proposals / (1 - output_proposals))
+    output_proposals = output_proposals.masked_fill(~output_proposals_valid,
+                                                    float('inf'))
+
+    output_memory = memory
+    output_memory = output_memory.masked_fill(~output_proposals_valid,
+                                              float(0))
+    output_memory = self.memory_trans_fc(output_memory)
+    output_memory = self.memory_trans_norm(output_memory)
+    # (bs, sum(hw), dim), (bs, sum(hw), 4)
+    return output_memory, output_proposals
diff --git a/mmdeploy/codebase/mmdet/models/detectors/detr.py b/mmdeploy/codebase/mmdet/models/detectors/detr.py
new file mode 100644
index 0000000000..d22fb4aa56
--- /dev/null
+++ b/mmdeploy/codebase/mmdet/models/detectors/detr.py
@@ -0,0 +1,45 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from mmdet.structures.det_data_sample import OptSampleList
+
+from mmdeploy.core import FUNCTION_REWRITER
+
+
+@FUNCTION_REWRITER.register_rewriter(
+    'mmdet.models.detectors.detr.DETR.pre_transformer')
+def detr__pre_transformer(self, img_feats, batch_data_samples: OptSampleList):
+    """Rewrite `pre_transformer` for default backend.
+
+    Support exporting without masks for padding info.
+
+    Args:
+        img_feats (Tuple[Tensor]): Tuple of features output from the neck,
+            has shape (bs, c, h, w).
+        batch_data_samples (List[:obj:`DetDataSample`]): The batch
+            data samples. It usually includes information such as
+            `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
+            Defaults to None.
+
+    Returns:
+        tuple[dict, dict]: The first dict contains the inputs of encoder
+        and the second dict contains the inputs of decoder.
+    """
+    feat = img_feats[-1]  # NOTE img_feats contains only one feature.
+    batch_size, feat_dim, h, w = feat.shape
+    # construct binary masks for the transformer.
+    assert batch_data_samples is not None
+    masks = None  # for single image inference
+    # [batch_size, embed_dim, h, w]
+    extra_kwargs = dict(B=batch_size, H=h, W=w, device=feat.device)
+    pos_embed = self.positional_encoding(mask=masks, **extra_kwargs)
+
+    # use `view` instead of `flatten` for dynamically exporting to ONNX
+    # [bs, c, h, w] -> [bs, h*w, c]
+    feat = feat.view(batch_size, feat_dim, -1).permute(0, 2, 1)
+    pos_embed = pos_embed.view(batch_size, feat_dim, -1).permute(0, 2, 1)
+    # mask flatten [bs, h, w] -> [bs, h*w] is skipped since masks is None
+
+    # prepare transformer_inputs_dict
+    encoder_inputs_dict = dict(feat=feat, feat_mask=masks, feat_pos=pos_embed)
+    decoder_inputs_dict = dict(memory_mask=masks, memory_pos=pos_embed)
+    return encoder_inputs_dict, decoder_inputs_dict
diff --git a/mmdeploy/codebase/mmdet/models/layers/__init__.py b/mmdeploy/codebase/mmdet/models/layers/__init__.py
index 0559af920f..3bd4dcaa68 100644
--- a/mmdeploy/codebase/mmdet/models/layers/__init__.py
+++ b/mmdeploy/codebase/mmdet/models/layers/__init__.py
@@ -2,5 +2,6 @@
 # recovery for mmyolo
 from mmdeploy.mmcv.ops import multiclass_nms  # noqa: F401, F403
 from . import matrix_nms  # noqa: F401, F403
+from . import positional_encoding  # noqa: F401, F403
 
 __all__ = ['multiclass_nms']
diff --git a/mmdeploy/codebase/mmdet/models/layers/positional_encoding.py b/mmdeploy/codebase/mmdet/models/layers/positional_encoding.py
new file mode 100644
index 0000000000..6caa9e1011
--- /dev/null
+++ b/mmdeploy/codebase/mmdet/models/layers/positional_encoding.py
@@ -0,0 +1,57 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch import Tensor
+
+from mmdeploy.core import FUNCTION_REWRITER
+
+
+@FUNCTION_REWRITER.register_rewriter(
+    func_name='mmdet.models.layers.positional_encoding.'
+    'SinePositionalEncoding.forward')
+def sine_positional_encoding_forward__default(self, mask: Tensor,
+                                              **kwargs) -> Tensor:
+    """Rewrite `forward` for default backend.
+    Use `mask=None` for single image inference.
+    Args:
+        mask (Tensor | None): ByteTensor mask. Non-zero values represent
+            ignored positions, while zero values mean valid positions
+            for this image. Shape [bs, h, w].
+
+    Returns:
+        pos (Tensor): Returned position embedding with shape
+            [bs, num_feats*2, h, w].
+    """
+    if mask is not None:
+        B, H, W = mask.shape
+        device = mask.device
+        mask = mask.to(torch.int)
+        not_mask = 1 - mask  # logical_not
+        y_embed = not_mask.cumsum(1, dtype=torch.float32)
+        x_embed = not_mask.cumsum(2, dtype=torch.float32)
+    else:
+        B, H, W = kwargs['B'], kwargs['H'], kwargs['W']
+        device = kwargs['device']
+        x_embed = torch.arange(1, W + 1, dtype=torch.float32, device=device)
+        x_embed = x_embed.view(1, 1, -1).repeat(B, H, 1)
+        y_embed = torch.arange(1, H + 1, dtype=torch.float32, device=device)
+        y_embed = y_embed.view(1, -1, 1).repeat(B, 1, W)
+
+    if self.normalize:
+        y_embed = (y_embed + self.offset) / \
+            (y_embed[:, -1:, :] + self.eps) * self.scale
+        x_embed = (x_embed + self.offset) / \
+            (x_embed[:, :, -1:] + self.eps) * self.scale
+    dim_t = torch.arange(self.num_feats, dtype=torch.float32, device=device)
+    dim_t = self.temperature**(2 * (dim_t // 2) / self.num_feats)
+    pos_x = x_embed[:, :, :, None] / dim_t
+    pos_y = y_embed[:, :, :, None] / dim_t
+    # use `view` instead of `flatten` for dynamically exporting to ONNX
+
+    pos_x = torch.stack(
+        (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()),
+        dim=4).view(B, H, W, -1)
+    pos_y = torch.stack(
+        (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()),
+        dim=4).view(B, H, W, -1)
+    pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+    return pos

From c973b46d5a5e356bd7488b54845198f9a5ecff51 Mon Sep 17 00:00:00 2001
From: RunningLeon
Date: Thu, 7 Sep 2023 11:27:13 +0800
Subject: [PATCH 6/7] fix

---
 mmdeploy/codebase/mmdet/deploy/object_detection_model.py | 6 +-----
 tests/test_codebase/test_mmdet/test_mmdet_models.py      | 2 +-
 2 files changed, 2 insertions(+), 6 deletions(-)

diff --git a/mmdeploy/codebase/mmdet/deploy/object_detection_model.py b/mmdeploy/codebase/mmdet/deploy/object_detection_model.py
index b562c3f41f..a1e52bffd3 100644
--- a/mmdeploy/codebase/mmdet/deploy/object_detection_model.py
+++ b/mmdeploy/codebase/mmdet/deploy/object_detection_model.py
@@ -300,11 +300,7 @@ def forward(self,
                                               data_samples)
         return data_samples
 
-    def predict(
-        self,
-        imgs: Tensor,
-        shape_info: Optional[torch.Tensor] = None
-    ) -> Tuple[np.ndarray, np.ndarray]:
+    def predict(self, imgs: Tensor) -> Tuple[np.ndarray, np.ndarray]:
         """The interface for predict.
Args: diff --git a/tests/test_codebase/test_mmdet/test_mmdet_models.py b/tests/test_codebase/test_mmdet/test_mmdet_models.py index 78b3255b06..ca1c5c1255 100644 --- a/tests/test_codebase/test_mmdet/test_mmdet_models.py +++ b/tests/test_codebase/test_mmdet/test_mmdet_models.py @@ -727,7 +727,7 @@ def test_predict_of_detr_detector(model_cfg_path, backend): from mmdet.structures import DetDataSample data_sample = DetDataSample(metainfo=dict(batch_input_shape=(64, 64))) rewrite_inputs = {'batch_inputs': img} - wrapped_model = WrapModel(model, 'predict', data_samples=[data_sample]) + wrapped_model = WrapModel(model, 'forward', data_samples=[data_sample]) rewrite_outputs, _ = get_rewrite_outputs( wrapped_model=wrapped_model, model_inputs=rewrite_inputs, From 4d8a6a9c883bc3e2a1a6f493708ba999876b7db0 Mon Sep 17 00:00:00 2001 From: RunningLeon Date: Mon, 11 Sep 2023 10:43:27 +0800 Subject: [PATCH 7/7] remove rewriting and move changes to mmdet --- .../mmdet/models/dense_heads/detr_head.py | 2 +- .../mmdet/models/detectors/__init__.py | 13 +- .../mmdet/models/detectors/deformable_detr.py | 152 ------------------ .../codebase/mmdet/models/detectors/detr.py | 45 ------ .../codebase/mmdet/models/layers/__init__.py | 1 - .../models/layers/positional_encoding.py | 57 ------- 6 files changed, 9 insertions(+), 261 deletions(-) delete mode 100644 mmdeploy/codebase/mmdet/models/detectors/deformable_detr.py delete mode 100644 mmdeploy/codebase/mmdet/models/detectors/detr.py delete mode 100644 mmdeploy/codebase/mmdet/models/layers/positional_encoding.py diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/detr_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/detr_head.py index 6d161b5c25..bb2bdee2a8 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/detr_head.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/detr_head.py @@ -56,9 +56,9 @@ def detrhead__predict_by_feat__default(self, ...].squeeze(-1) det_bboxes = bbox_cxcywh_to_xyxy(bbox_preds) + det_bboxes.clamp_(min=0., max=1.) shape_scale = img_shape.flip(1).repeat(1, 2).unsqueeze(1) det_bboxes = det_bboxes * shape_scale - det_bboxes.clamp_(min=0) det_bboxes = torch.cat((det_bboxes, scores.unsqueeze(-1)), -1) return det_bboxes, det_labels diff --git a/mmdeploy/codebase/mmdet/models/detectors/__init__.py b/mmdeploy/codebase/mmdet/models/detectors/__init__.py index 7c8cbb7bec..ac1d82d7a6 100644 --- a/mmdeploy/codebase/mmdet/models/detectors/__init__.py +++ b/mmdeploy/codebase/mmdet/models/detectors/__init__.py @@ -1,9 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. -from . import (base_detr, deformable_detr, detr, maskformer, - panoptic_two_stage_segmentor, single_stage, - single_stage_instance_seg, two_stage) +from . import (base_detr, maskformer, panoptic_two_stage_segmentor, + single_stage, single_stage_instance_seg, two_stage) __all__ = [ - 'base_detr', 'single_stage', 'single_stage_instance_seg', 'two_stage', - 'panoptic_two_stage_segmentor', 'maskformer', 'detr', 'deformable_detr' + 'base_detr', + 'single_stage', + 'single_stage_instance_seg', + 'two_stage', + 'panoptic_two_stage_segmentor', + 'maskformer', ] diff --git a/mmdeploy/codebase/mmdet/models/detectors/deformable_detr.py b/mmdeploy/codebase/mmdet/models/detectors/deformable_detr.py deleted file mode 100644 index 56141279ca..0000000000 --- a/mmdeploy/codebase/mmdet/models/detectors/deformable_detr.py +++ /dev/null @@ -1,152 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
-from typing import Dict, Tuple
-
-import torch
-from mmdet.structures.det_data_sample import OptSampleList
-from torch import Tensor
-
-from mmdeploy.core import FUNCTION_REWRITER
-
-
-@FUNCTION_REWRITER.register_rewriter(
-    'mmdet.models.detectors.deformable_detr.DeformableDETR.pre_transformer')
-def deformable_detr__pre_transformer(
-        self,
-        mlvl_feats: Tuple[Tensor],
-        batch_data_samples: OptSampleList = None) -> Tuple[Dict]:
-    """Rewrite `pre_transformer` for default backend.
-
-    Support exporting without masks for padding info.
-
-    Args:
-        mlvl_feats (tuple[Tensor]): Multi-level features that may have
-            different resolutions, output from neck. Each feature has
-            shape (bs, dim, h_lvl, w_lvl), where 'lvl' means 'layer'.
-        batch_data_samples (list[:obj:`DetDataSample`], optional): The
-            batch data samples. It usually includes information such
-            as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
-            Defaults to None.
-
-    Returns:
-        tuple[dict]: The first dict contains the inputs of encoder and the
-        second dict contains the inputs of decoder.
-    """
-    # construct binary masks for the transformer.
-    assert batch_data_samples is not None
-    batch_size = mlvl_feats[0].shape[0]
-    device = mlvl_feats[0].device
-    mlvl_masks = []
-    mlvl_pos_embeds = []
-    for feat in mlvl_feats:
-        mlvl_masks.append(None)
-        shape_info = dict(
-            B=batch_size, H=feat.shape[2], W=feat.shape[3], device=device)
-        mlvl_pos_embeds.append(
-            self.positional_encoding(mask=None, **shape_info))
-
-    feat_flatten = []
-    lvl_pos_embed_flatten = []
-    spatial_shapes = []
-    for lvl, (feat, mask, pos_embed) in enumerate(
-            zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)):
-        batch_size, c, h, w = feat.shape
-        spatial_shape = torch._shape_as_tensor(feat)[2:].to(feat.device)
-        # [bs, c, h_lvl, w_lvl] -> [bs, h_lvl*w_lvl, c]
-        feat = feat.view(batch_size, c, -1).permute(0, 2, 1)
-        pos_embed = pos_embed.view(batch_size, c, -1).permute(0, 2, 1)
-        lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
-        # mask flatten [bs, h_lvl, w_lvl] -> [bs, h_lvl*w_lvl] is skipped
-        feat_flatten.append(feat)
-        lvl_pos_embed_flatten.append(lvl_pos_embed)
-        spatial_shapes.append(spatial_shape)
-
-    # (bs, num_feat_points, dim)
-    feat_flatten = torch.cat(feat_flatten, 1)
-    lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
-    # (bs, num_feat_points), where num_feat_points = sum_lvl(h_lvl*w_lvl)
-    mask_flatten = None
-
-    # (num_level, 2)
-    spatial_shapes = torch.cat(spatial_shapes).view(-1, 2)
-    level_start_index = torch.cat((
-        spatial_shapes.new_zeros((1, )),  # (num_level)
-        spatial_shapes.prod(1).cumsum(0)[:-1]))
-    valid_ratios = torch.ones(
-        batch_size, len(mlvl_feats), 2, device=device)  # (bs, num_level, 2)
-
-    encoder_inputs_dict = dict(
-        feat=feat_flatten,
-        feat_mask=mask_flatten,
-        feat_pos=lvl_pos_embed_flatten,
-        spatial_shapes=spatial_shapes,
-        level_start_index=level_start_index,
-        valid_ratios=valid_ratios)
-    decoder_inputs_dict = dict(
-        memory_mask=mask_flatten,
-        spatial_shapes=spatial_shapes,
-        level_start_index=level_start_index,
-        valid_ratios=valid_ratios)
-    return encoder_inputs_dict, decoder_inputs_dict
-
-
-@FUNCTION_REWRITER.register_rewriter(
-    'mmdet.models.detectors.deformable_detr.'
-    'DeformableDETR.gen_encoder_output_proposals')
-def deformable_detr__gen_encoder_output_proposals(
-        self, memory: Tensor, memory_mask: Tensor,
-        spatial_shapes: Tensor) -> Tuple[Tensor, Tensor]:
-    """Rewrite `gen_encoder_output_proposals` for default backend.
-
-    Support exporting with `memory_mask=None`.
-
-    Args:
-        memory (Tensor): The output embeddings of the Transformer encoder,
-            has shape (bs, num_feat_points, dim).
-        memory_mask (Tensor): ByteTensor, the padding mask of the memory,
-            has shape (bs, num_feat_points).
-        spatial_shapes (Tensor): Spatial shapes of features in all levels,
-            has shape (num_levels, 2), last dimension represents (h, w).
-
-    Returns:
-        tuple: A tuple of transformed memory and proposals.
-
-        - output_memory (Tensor): The transformed memory for obtaining
-          top-k proposals, has shape (bs, num_feat_points, dim).
-        - output_proposals (Tensor): The inverse-normalized proposal, has
-          shape (batch_size, num_keys, 4) with the last dimension arranged
-          as (cx, cy, w, h).
-    """
-    assert memory_mask is None, 'only support `memory_mask=None`'
-    bs = memory.size(0)
-    proposals = []
-    for lvl, HW in enumerate(spatial_shapes):
-        H, W = HW
-        grid_y, grid_x = torch.meshgrid(
-            torch.linspace(
-                0, H - 1, H, dtype=torch.float32, device=memory.device),
-            torch.linspace(
-                0, W - 1, W, dtype=torch.float32, device=memory.device))
-        grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)
-        scale = HW.unsqueeze(0).flip(dims=[0, 1]).view(bs, 1, 1, 2)
-        grid = (grid.unsqueeze(0).expand(bs, -1, -1, -1) + 0.5) / scale
-        wh = torch.ones_like(grid) * 0.05 * (2.0**lvl)
-        proposal = torch.cat((grid, wh), -1).view(bs, -1, 4)
-        proposals.append(proposal)
-    output_proposals = torch.cat(proposals, 1)
-    # do not use `all` to make it exportable to onnx
-    output_proposals_valid = ((output_proposals > 0.01) &
-                              (output_proposals < 0.99)).sum(
-                                  -1,
-                                  keepdim=True) == output_proposals.shape[-1]
-    # inverse_sigmoid
-    output_proposals = torch.log(output_proposals / (1 - output_proposals))
-    output_proposals = output_proposals.masked_fill(~output_proposals_valid,
-                                                    float('inf'))
-
-    output_memory = memory
-    output_memory = output_memory.masked_fill(~output_proposals_valid,
-                                              float(0))
-    output_memory = self.memory_trans_fc(output_memory)
-    output_memory = self.memory_trans_norm(output_memory)
-    # (bs, sum(hw), dim), (bs, sum(hw), 4)
-    return output_memory, output_proposals
diff --git a/mmdeploy/codebase/mmdet/models/detectors/detr.py b/mmdeploy/codebase/mmdet/models/detectors/detr.py
deleted file mode 100644
index d22fb4aa56..0000000000
--- a/mmdeploy/codebase/mmdet/models/detectors/detr.py
+++ /dev/null
@@ -1,45 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-
-from mmdet.structures.det_data_sample import OptSampleList
-
-from mmdeploy.core import FUNCTION_REWRITER
-
-
-@FUNCTION_REWRITER.register_rewriter(
-    'mmdet.models.detectors.detr.DETR.pre_transformer')
-def detr__pre_transformer(self, img_feats, batch_data_samples: OptSampleList):
-    """Rewrite `pre_transformer` for default backend.
-
-    Support exporting without masks for padding info.
-
-    Args:
-        img_feats (Tuple[Tensor]): Tuple of features output from the neck,
-            has shape (bs, c, h, w).
-        batch_data_samples (List[:obj:`DetDataSample`]): The batch
-            data samples. It usually includes information such as
-            `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
-            Defaults to None.
-
-    Returns:
-        tuple[dict, dict]: The first dict contains the inputs of encoder
-        and the second dict contains the inputs of decoder.
-    """
-    feat = img_feats[-1]  # NOTE img_feats contains only one feature.
-    batch_size, feat_dim, h, w = feat.shape
-    # construct binary masks for the transformer.
-    assert batch_data_samples is not None
-    masks = None  # for single image inference
-    # [batch_size, embed_dim, h, w]
-    extra_kwargs = dict(B=batch_size, H=h, W=w, device=feat.device)
-    pos_embed = self.positional_encoding(mask=masks, **extra_kwargs)
-
-    # use `view` instead of `flatten` for dynamically exporting to ONNX
-    # [bs, c, h, w] -> [bs, h*w, c]
-    feat = feat.view(batch_size, feat_dim, -1).permute(0, 2, 1)
-    pos_embed = pos_embed.view(batch_size, feat_dim, -1).permute(0, 2, 1)
-    # mask flatten [bs, h, w] -> [bs, h*w] is skipped since masks is None
-
-    # prepare transformer_inputs_dict
-    encoder_inputs_dict = dict(feat=feat, feat_mask=masks, feat_pos=pos_embed)
-    decoder_inputs_dict = dict(memory_mask=masks, memory_pos=pos_embed)
-    return encoder_inputs_dict, decoder_inputs_dict
diff --git a/mmdeploy/codebase/mmdet/models/layers/__init__.py b/mmdeploy/codebase/mmdet/models/layers/__init__.py
index 3bd4dcaa68..0559af920f 100644
--- a/mmdeploy/codebase/mmdet/models/layers/__init__.py
+++ b/mmdeploy/codebase/mmdet/models/layers/__init__.py
@@ -2,6 +2,5 @@
 # recovery for mmyolo
 from mmdeploy.mmcv.ops import multiclass_nms  # noqa: F401, F403
 from . import matrix_nms  # noqa: F401, F403
-from . import positional_encoding  # noqa: F401, F403
 
 __all__ = ['multiclass_nms']
diff --git a/mmdeploy/codebase/mmdet/models/layers/positional_encoding.py b/mmdeploy/codebase/mmdet/models/layers/positional_encoding.py
deleted file mode 100644
index 6caa9e1011..0000000000
--- a/mmdeploy/codebase/mmdet/models/layers/positional_encoding.py
+++ /dev/null
@@ -1,57 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from torch import Tensor
-
-from mmdeploy.core import FUNCTION_REWRITER
-
-
-@FUNCTION_REWRITER.register_rewriter(
-    func_name='mmdet.models.layers.positional_encoding.'
-    'SinePositionalEncoding.forward')
-def sine_positional_encoding_forward__default(self, mask: Tensor,
-                                              **kwargs) -> Tensor:
-    """Rewrite `forward` for default backend.
-    Use `mask=None` for single image inference.
-    Args:
-        mask (Tensor | None): ByteTensor mask. Non-zero values represent
-            ignored positions, while zero values mean valid positions
-            for this image. Shape [bs, h, w].
-
-    Returns:
-        pos (Tensor): Returned position embedding with shape
-            [bs, num_feats*2, h, w].
-    """
-    if mask is not None:
-        B, H, W = mask.shape
-        device = mask.device
-        mask = mask.to(torch.int)
-        not_mask = 1 - mask  # logical_not
-        y_embed = not_mask.cumsum(1, dtype=torch.float32)
-        x_embed = not_mask.cumsum(2, dtype=torch.float32)
-    else:
-        B, H, W = kwargs['B'], kwargs['H'], kwargs['W']
-        device = kwargs['device']
-        x_embed = torch.arange(1, W + 1, dtype=torch.float32, device=device)
-        x_embed = x_embed.view(1, 1, -1).repeat(B, H, 1)
-        y_embed = torch.arange(1, H + 1, dtype=torch.float32, device=device)
-        y_embed = y_embed.view(1, -1, 1).repeat(B, 1, W)
-
-    if self.normalize:
-        y_embed = (y_embed + self.offset) / \
-            (y_embed[:, -1:, :] + self.eps) * self.scale
-        x_embed = (x_embed + self.offset) / \
-            (x_embed[:, :, -1:] + self.eps) * self.scale
-    dim_t = torch.arange(self.num_feats, dtype=torch.float32, device=device)
-    dim_t = self.temperature**(2 * (dim_t // 2) / self.num_feats)
-    pos_x = x_embed[:, :, :, None] / dim_t
-    pos_y = y_embed[:, :, :, None] / dim_t
-    # use `view` instead of `flatten` for dynamically exporting to ONNX
-
-    pos_x = torch.stack(
-        (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()),
-        dim=4).view(B, H, W, -1)
-    pos_y = torch.stack(
-        (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()),
-        dim=4).view(B, H, W, -1)
-    pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
-    return pos
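
The equivalence these patches rely on, sketched standalone: when no pixel is padded, the `cumsum` over `not_mask` in `SinePositionalEncoding.forward` reduces to plain `arange` row/column indices, so the position embedding can be computed from `B`, `H`, `W` alone and the mask dropped from the exported graph. The snippet below is an illustrative check, not part of the patch series; the shapes are arbitrary example values:

import torch


def embed_from_mask(mask):
    # original mmdet path: cumulative sum over valid (non-padded) positions
    not_mask = 1 - mask.to(torch.int)
    y_embed = not_mask.cumsum(1, dtype=torch.float32)
    x_embed = not_mask.cumsum(2, dtype=torch.float32)
    return y_embed, x_embed


def embed_from_shape(B, H, W):
    # rewritten mask-free path: indices generated directly from B, H, W
    x_embed = torch.arange(1, W + 1, dtype=torch.float32)
    x_embed = x_embed.view(1, 1, -1).repeat(B, H, 1)
    y_embed = torch.arange(1, H + 1, dtype=torch.float32)
    y_embed = y_embed.view(1, -1, 1).repeat(B, 1, W)
    return y_embed, x_embed


B, H, W = 2, 25, 34
y0, x0 = embed_from_mask(torch.zeros(B, H, W))  # all-zero mask: no padding
y1, x1 = embed_from_shape(B, H, W)
assert torch.equal(y0, y1) and torch.equal(x0, x1)

Because both branches produce identical embeddings for unpadded batches, the final patch can delete the mmdeploy-side rewriters and rely on the equivalent logic upstreamed to mmdet.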