Skip to content

Commit e0ed95e

Browse files
authored
[Fix] fix unittest and suppress warning (#1552)
* fix unittest and some warnings
* fix reading of the pool-mode string
* use snake_case naming
1 parent 0e65606 commit e0ed95e

File tree

13 files changed

+104
-169
lines changed

13 files changed

+104
-169
lines changed

csrc/mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -201,14 +201,15 @@ nvinfer1::IPluginV2 *TRTRoIAlignCreator::createPlugin(
201201

202202
if (field_name.compare("mode") == 0) {
203203
int data_size = fc->fields[i].length;
204+
ASSERT(data_size > 0);
204205
const char *data_start = static_cast<const char *>(fc->fields[i].data);
205-
std::string poolModeStr(data_start, data_size);
206-
if (poolModeStr == "avg") {
206+
std::string pool_mode(data_start);
207+
if (pool_mode == "avg") {
207208
poolMode = 1;
208-
} else if (poolModeStr == "max") {
209+
} else if (pool_mode == "max") {
209210
poolMode = 0;
210211
} else {
211-
std::cout << "Unknown pool mode \"" << poolModeStr << "\"." << std::endl;
212+
std::cout << "Unknown pool mode \"" << pool_mode << "\"." << std::endl;
212213
}
213214
ASSERT(poolMode >= 0);
214215
}

mmdeploy/backend/tensorrt/utils.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,17 @@
1313
from .init_plugins import load_tensorrt_plugin
1414

1515

16-
def save(engine: trt.ICudaEngine, path: str) -> None:
16+
def save(engine: Any, path: str) -> None:
1717
"""Serialize TensorRT engine to disk.
1818
1919
Args:
20-
engine (tensorrt.ICudaEngine): TensorRT engine to be serialized.
20+
engine (Any): TensorRT engine to be serialized.
2121
path (str): The absolute disk path to write the engine.
2222
"""
2323
with open(path, mode='wb') as f:
24-
f.write(bytearray(engine.serialize()))
24+
if isinstance(engine, trt.ICudaEngine):
25+
engine = engine.serialize()
26+
f.write(bytearray(engine))
2527

2628

2729
def load(path: str, allocator: Optional[Any] = None) -> trt.ICudaEngine:
@@ -226,7 +228,10 @@ def from_onnx(onnx_model: Union[str, onnx.ModelProto],
226228
builder.int8_calibrator = config.int8_calibrator
227229

228230
# create engine
229-
engine = builder.build_engine(network, config)
231+
if hasattr(builder, 'build_serialized_network'):
232+
engine = builder.build_serialized_network(network, config)
233+
else:
234+
engine = builder.build_engine(network, config)
230235

231236
assert engine is not None, 'Failed to create TensorRT engine'
232237

mmdeploy/codebase/mmdet/deploy/object_detection_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -601,7 +601,7 @@ class labels of shape [N, num_det].
601601
scores = out[:, :, 1:2]
602602
boxes = out[:, :, 2:6] * scales
603603
dets = torch.cat([boxes, scores], dim=2)
604-
return dets, torch.tensor(labels, dtype=torch.int32)
604+
return dets, labels.to(torch.int32)
605605

606606

607607
@__BACKEND_MODEL.register_module('sdk')

mmdeploy/codebase/mmdet/models/backbones.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def focus__forward__ncnn(self, x):
4646

4747
x = x.reshape(batch_size, c * h, 1, w)
4848
_b, _c, _h, _w = x.shape
49-
g = _c // 2
49+
g = torch.div(_c, 2, rounding_mode='floor')
5050
# fuse to ncnn's shufflechannel
5151
x = x.view(_b, g, 2, _h, _w)
5252
x = torch.transpose(x, 1, 2).contiguous()
@@ -55,13 +55,14 @@ def focus__forward__ncnn(self, x):
5555
x = x.reshape(_b, c * h * w, 1, 1)
5656

5757
_b, _c, _h, _w = x.shape
58-
g = _c // 2
58+
g = torch.div(_c, 2, rounding_mode='floor')
5959
# fuse to ncnn's shufflechannel
6060
x = x.view(_b, g, 2, _h, _w)
6161
x = torch.transpose(x, 1, 2).contiguous()
6262
x = x.view(_b, -1, _h, _w)
6363

64-
x = x.reshape(_b, c * 4, h // 2, w // 2)
64+
x = x.reshape(_b, c * 4, torch.div(h, 2, rounding_mode='floor'),
65+
torch.div(w, 2, rounding_mode='floor'))
6566

6667
return self.conv(x)
6768

@@ -198,8 +199,12 @@ def shift_window_msa__forward__default(self, query, hw_shape):
198199
[query,
199200
query.new_zeros(B, C, self.window_size, query.shape[-1])],
200201
dim=-2)
201-
slice_h = (H + self.window_size - 1) // self.window_size * self.window_size
202-
slice_w = (W + self.window_size - 1) // self.window_size * self.window_size
202+
slice_h = torch.div(
203+
(H + self.window_size - 1), self.window_size,
204+
rounding_mode='floor') * self.window_size
205+
slice_w = torch.div(
206+
(W + self.window_size - 1), self.window_size,
207+
rounding_mode='floor') * self.window_size
203208
query = query[:, :, :slice_h, :slice_w]
204209
query = query.permute(0, 2, 3, 1).contiguous()
205210
H_pad, W_pad = query.shape[1], query.shape[2]

mmdeploy/core/rewriters/rewriter_utils.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,29 @@ def decorator(object):
328328

329329
return decorator
330330

331+
def remove_record(self, object: Any, filter_cb: Optional[Callable] = None):
332+
"""Remove record.
333+
334+
Args:
335+
object (Any): The object to remove.
336+
filter_cb (Callable): Check if the object need to be remove.
337+
Defaults to None.
338+
"""
339+
key_to_pop = []
340+
for key, records in self._rewrite_records.items():
341+
for rec in records:
342+
if rec['_object'] == object:
343+
if filter_cb is not None:
344+
if filter_cb(rec):
345+
continue
346+
key_to_pop.append((key, rec))
347+
348+
for key, rec in key_to_pop:
349+
records = self._rewrite_records[key]
350+
records.remove(rec)
351+
if len(records) == 0:
352+
self._rewrite_records.pop(key)
353+
331354

332355
class ContextCaller:
333356
"""A callable object used in RewriteContext.

mmdeploy/mmcv/ops/nms.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -511,7 +511,7 @@ def multiclass_nms(boxes: Tensor,
511511

512512

513513
@FUNCTION_REWRITER.register_rewriter(
514-
func_name='mmdeploy.mmcv.ops.nms.bbox_nms._multiclass_nms',
514+
func_name='mmdeploy.mmcv.ops.nms._multiclass_nms',
515515
backend=Backend.COREML.value)
516516
def multiclass_nms__coreml(boxes: Tensor,
517517
scores: Tensor,
@@ -574,8 +574,7 @@ def _xywh2xyxy(boxes):
574574

575575

576576
@FUNCTION_REWRITER.register_rewriter(
577-
func_name='mmdeploy.mmcv.ops.nms.bbox_nms._multiclass_nms',
578-
ir=IR.TORCHSCRIPT)
577+
func_name='mmdeploy.mmcv.ops.nms._multiclass_nms', ir=IR.TORCHSCRIPT)
579578
def multiclass_nms__torchscript(boxes: Tensor,
580579
scores: Tensor,
581580
max_output_boxes_per_class: int = 1000,
@@ -676,8 +675,7 @@ def symbolic(g, bboxes, scores, score_thr, iou_thr, max_size_p_class,
676675

677676

678677
@FUNCTION_REWRITER.register_rewriter(
679-
func_name='mmdeploy.mmcv.ops.nms.bbox_nms._multiclass_nms',
680-
backend='ascend')
678+
func_name='mmdeploy.mmcv.ops.nms._multiclass_nms', backend='ascend')
681679
def multiclass_nms__ascend(boxes: Tensor,
682680
scores: Tensor,
683681
max_output_boxes_per_class: int = 1000,

mmdeploy/utils/test.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@
1414
from mmengine.model import BaseModel
1515
from torch import nn
1616

17+
try:
18+
from torch.testing import assert_close as torch_assert_close
19+
except Exception:
20+
from torch.testing import assert_allclose as torch_assert_close
21+
1722
import mmdeploy.codebase # noqa: F401,F403
1823
from mmdeploy.core import RewriterContext, patch_model
1924
from mmdeploy.utils import (IR, Backend, get_backend, get_dynamic_axes,
@@ -293,8 +298,7 @@ def assert_allclose(expected: List[Union[torch.Tensor, np.ndarray]],
293298
if isinstance(actual[i], (list, np.ndarray)):
294299
actual[i] = torch.tensor(actual[i])
295300
try:
296-
torch.testing.assert_allclose(
297-
actual[i], expected[i], rtol=1e-03, atol=1e-05)
301+
torch_assert_close(actual[i], expected[i], rtol=1e-03, atol=1e-05)
298302
except AssertionError as error:
299303
if tolerate_small_mismatch:
300304
assert '(0.00%)' in str(error), str(error)

tests/test_codebase/test_mmdet/test_mmdet_models.py

Lines changed: 12 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@
99
import numpy as np
1010
import pytest
1111
import torch
12+
13+
try:
14+
from torch.testing import assert_close as torch_assert_close
15+
except Exception:
16+
from torch.testing import assert_allclose as torch_assert_close
17+
1218
from mmengine import Config
1319
from mmengine.config import ConfigDict
1420

@@ -237,7 +243,7 @@ def single_level_grid_priors(input):
237243
# test forward
238244
with RewriterContext({}, backend_type):
239245
wrap_output = wrapped_func(x)
240-
torch.testing.assert_allclose(output, wrap_output)
246+
torch_assert_close(output, wrap_output)
241247

242248
onnx_prefix = tempfile.NamedTemporaryFile().name
243249

@@ -341,23 +347,6 @@ def get_ssd_head_model():
341347
return model
342348

343349

344-
def get_fcos_head_model():
345-
"""FCOS Head Config."""
346-
test_cfg = Config(
347-
dict(
348-
deploy_nms_pre=0,
349-
min_bbox_size=0,
350-
score_thr=0.05,
351-
nms=dict(type='nms', iou_threshold=0.5),
352-
max_per_img=100))
353-
354-
from mmdet.models.dense_heads import FCOSHead
355-
model = FCOSHead(num_classes=4, in_channels=1, test_cfg=test_cfg)
356-
357-
model.requires_grad_(False)
358-
return model
359-
360-
361350
def get_focus_backbone_model():
362351
"""Backbone Focus Config."""
363352
from mmdet.models.backbones.csp_darknet import Focus
@@ -412,10 +401,8 @@ def get_reppoints_head_model():
412401

413402
def get_detrhead_model():
414403
"""DETR head Config."""
415-
from mmdet.models import build_head
416-
from mmdet.utils import register_all_modules
417-
register_all_modules()
418-
model = build_head(
404+
from mmdet.registry import MODELS
405+
model = MODELS.build(
419406
dict(
420407
type='DETRHead',
421408
num_classes=4,
@@ -431,8 +418,7 @@ def get_detrhead_model():
431418
dict(
432419
type='MultiheadAttention',
433420
embed_dims=4,
434-
num_heads=1,
435-
dropout=0.1)
421+
num_heads=1)
436422
],
437423
ffn_cfgs=dict(
438424
type='FFN',
@@ -442,8 +428,6 @@ def get_detrhead_model():
442428
ffn_drop=0.,
443429
act_cfg=dict(type='ReLU', inplace=True),
444430
),
445-
feedforward_channels=32,
446-
ffn_dropout=0.1,
447431
operation_order=('self_attn', 'norm', 'ffn', 'norm'))),
448432
decoder=dict(
449433
type='DetrTransformerDecoder',
@@ -454,8 +438,7 @@ def get_detrhead_model():
454438
attn_cfgs=dict(
455439
type='MultiheadAttention',
456440
embed_dims=4,
457-
num_heads=1,
458-
dropout=0.1),
441+
num_heads=1),
459442
ffn_cfgs=dict(
460443
type='FFN',
461444
embed_dims=4,
@@ -465,7 +448,6 @@ def get_detrhead_model():
465448
act_cfg=dict(type='ReLU', inplace=True),
466449
),
467450
feedforward_channels=32,
468-
ffn_dropout=0.1,
469451
operation_order=('self_attn', 'norm', 'cross_attn',
470452
'norm', 'ffn', 'norm')),
471453
)),
@@ -536,7 +518,7 @@ def test_focus_forward(backend_type):
536518
for model_output, rewrite_output in zip(model_outputs[0], rewrite_outputs):
537519
model_output = model_output.squeeze()
538520
rewrite_output = rewrite_output.squeeze()
539-
torch.testing.assert_allclose(
521+
torch_assert_close(
540522
model_output, rewrite_output, rtol=1e-03, atol=1e-05)
541523

542524

@@ -578,77 +560,6 @@ def test_l2norm_forward(backend_type):
578560
model_output[0], rewrite_output, rtol=1e-03, atol=1e-05)
579561

580562

581-
def test_predict_by_feat_of_fcos_head_ncnn():
582-
backend_type = Backend.NCNN
583-
check_backend(backend_type)
584-
fcos_head = get_fcos_head_model()
585-
fcos_head.cpu().eval()
586-
s = 128
587-
batch_img_metas = [{
588-
'scale_factor': np.ones(4),
589-
'pad_shape': (s, s, 3),
590-
'img_shape': (s, s, 3)
591-
}]
592-
593-
output_names = ['detection_output']
594-
deploy_cfg = Config(
595-
dict(
596-
backend_config=dict(type=backend_type.value),
597-
onnx_config=dict(output_names=output_names, input_shape=None),
598-
codebase_config=dict(
599-
type='mmdet',
600-
task='ObjectDetection',
601-
model_type='ncnn_end2end',
602-
post_processing=dict(
603-
score_threshold=0.05,
604-
iou_threshold=0.5,
605-
max_output_boxes_per_class=200,
606-
pre_top_k=5000,
607-
keep_top_k=100,
608-
background_label_id=-1,
609-
))))
610-
611-
# the cls_score's size: (1, 36, 32, 32), (1, 36, 16, 16),
612-
# (1, 36, 8, 8), (1, 36, 4, 4), (1, 36, 2, 2).
613-
# the bboxes's size: (1, 36, 32, 32), (1, 36, 16, 16),
614-
# (1, 36, 8, 8), (1, 36, 4, 4), (1, 36, 2, 2)
615-
seed_everything(1234)
616-
cls_score = [
617-
torch.rand(1, fcos_head.num_classes, pow(2, i), pow(2, i))
618-
for i in range(5, 0, -1)
619-
]
620-
seed_everything(5678)
621-
bboxes = [torch.rand(1, 4, pow(2, i), pow(2, i)) for i in range(5, 0, -1)]
622-
623-
seed_everything(9101)
624-
centernesses = [
625-
torch.rand(1, 1, pow(2, i), pow(2, i)) for i in range(5, 0, -1)
626-
]
627-
628-
# to get outputs of onnx model after rewrite
629-
batch_img_metas[0]['img_shape'] = torch.Tensor([s, s])
630-
wrapped_model = WrapModel(
631-
fcos_head,
632-
'predict_by_feat',
633-
batch_img_metas=batch_img_metas,
634-
with_nms=True)
635-
rewrite_inputs = {
636-
'cls_scores': cls_score,
637-
'bbox_preds': bboxes,
638-
'centernesses': centernesses
639-
}
640-
rewrite_outputs, is_backend_output = get_rewrite_outputs(
641-
wrapped_model=wrapped_model,
642-
model_inputs=rewrite_inputs,
643-
deploy_cfg=deploy_cfg)
644-
645-
# output should be of shape [1, N, 6]
646-
if is_backend_output:
647-
assert rewrite_outputs[0].shape[-1] == 6
648-
else:
649-
assert rewrite_outputs.shape[-1] == 6
650-
651-
652563
@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME, Backend.NCNN])
653564
def test_predict_by_feat_of_rpn_head(backend_type: Backend):
654565
check_backend(backend_type)

tests/test_codebase/test_mmdet3d/test_voxel_detection_model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,10 @@ def setup_class(cls):
5757
deploy_cfg=deploy_cfg,
5858
model_cfg=model_cfg)
5959

60+
@classmethod
61+
def teardown_class(cls):
62+
cls.wrapper.recover()
63+
6064
@pytest.mark.skipif(
6165
reason='Only support GPU test',
6266
condition=not torch.cuda.is_available())

0 commit comments

Comments (0)