@@ -202,13 +202,13 @@ nvinfer1::IPluginV2 *TRTRoIAlignCreator::createPlugin(
     if (field_name.compare("mode") == 0) {
       int data_size = fc->fields[i].length;
       const char *data_start = static_cast<const char *>(fc->fields[i].data);
-      std::string poolModeStr(data_start, data_size);
-      if (poolModeStr == "avg") {
+      std::string pool_mode_str(data_start);
+      if (pool_mode_str == "avg") {
         poolMode = 1;
-      } else if (poolModeStr == "max") {
+      } else if (pool_mode_str == "max") {
         poolMode = 0;
       } else {
-        std::cout << "Unknown pool mode \"" << poolModeStr << "\"." << std::endl;
+        std::cout << "Unknown pool mode \"" << pool_mode_str << "\"." << std::endl;
       }
       ASSERT(poolMode >= 0);
     }
1 change: 0 additions & 1 deletion mmdeploy/codebase/mmcls/models/utils/attention.py
@@ -143,7 +143,6 @@ def shift_window_msa__forward__default(ctx, self, query, hw_shape):
     'mmcls.models.utils.ShiftWindowMSA.get_attn_mask',
     extra_checkers=LibVersionChecker('mmcls', min_version='0.21.0'))
 def shift_window_msa__get_attn_mask__default(ctx,
-                                             self,
                                              hw_shape,
                                              window_size,
                                              shift_size,
2 changes: 1 addition & 1 deletion mmdeploy/codebase/mmdet3d/core/bbox/fcos3d_bbox_coder.py
@@ -8,7 +8,7 @@

 @FUNCTION_REWRITER.register_rewriter(
     'mmdet3d.core.bbox.coders.fcos3d_bbox_coder.FCOS3DBBoxCoder.decode_yaw')
-def decode_yaw(ctx, self, bbox, centers2d, dir_cls, dir_offset, cam2img):
+def decode_yaw(ctx, bbox, centers2d, dir_cls, dir_offset, cam2img):
     """Decode yaw angle and change it from local to global. Rewrite this func
     to use slice instead of the original operation.
     Args:
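(Both rewrites above target static methods — `ShiftWindowMSA.get_attn_mask` and `FCOS3DBBoxCoder.decode_yaw` are declared with `@staticmethod` upstream — so the unused `self` parameter can be dropped once the rewriter below learns to rebind static methods correctly.)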
12 changes: 11 additions & 1 deletion mmdeploy/core/rewriters/function_rewriter.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import inspect
 from typing import (Any, Callable, Dict, List, MutableSequence, Optional,
                     Tuple, Union)
 
@@ -72,7 +73,16 @@ def _set_func(origin_func_path: str,
                      rewrite_func,
                      ignore_refs=ignore_refs,
                      ignore_keys=ignore_keys)
-    exec(f'{origin_func_path} = rewrite_func')
+
+    is_static_method = False
+    if method_class:
+        origin_type = inspect.getattr_static(module_or_class, split_path[-1])
+        is_static_method = isinstance(origin_type, staticmethod)
+
+    if is_static_method:
+        exec(f'{origin_func_path} = staticmethod(rewrite_func)')
+    else:
+        exec(f'{origin_func_path} = rewrite_func')
 
 
 def _del_func(path: str):
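For context, a minimal sketch of why `_set_func` now special-cases static methods (the class and function names below are illustrative, not part of mmdeploy): assigning a plain function to a class attribute creates a regular method, so a rewritten staticmethod would suddenly receive the instance as its first argument. `inspect.getattr_static` fetches the raw descriptor without triggering the descriptor protocol, which is what makes the check possible.

import inspect

class Model:  # hypothetical stand-in for a class being rewritten
    @staticmethod
    def preprocess(x):
        return x + 1

def rewrite_preprocess(x):  # the rewrite keeps the staticmethod signature
    return x + 2

origin = inspect.getattr_static(Model, 'preprocess')
if isinstance(origin, staticmethod):
    # Re-wrap so the rewrite stays a staticmethod after rebinding; without
    # this, Model().preprocess(1) would pass the instance as `x`.
    Model.preprocess = staticmethod(rewrite_preprocess)
else:
    Model.preprocess = rewrite_preprocess

assert Model().preprocess(1) == 3  # no implicit `self` is passed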
23 changes: 23 additions & 0 deletions mmdeploy/core/rewriters/rewriter_utils.py
@@ -326,6 +326,29 @@ def decorator(object):

         return decorator
 
+    def remove_record(self, object: Any, filter_cb: Optional[Callable] = None):
+        """Remove record.
+
+        Args:
+            object (Any): The object to remove.
+            filter_cb (Callable): A callback that returns True for records
+                that should be kept. Defaults to None.
+        """
+        key_to_pop = []
+        for key, records in self._rewrite_records.items():
+            for rec in records:
+                if rec['_object'] == object:
+                    if filter_cb is not None:
+                        if filter_cb(rec):
+                            continue
+                    key_to_pop.append((key, rec))
+
+        for key, rec in key_to_pop:
+            records = self._rewrite_records[key]
+            records.remove(rec)
+            if len(records) == 0:
+                self._rewrite_records.pop(key)
+
 
 class ContextCaller:
     """A callable object used in RewriteContext.
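A hedged usage sketch of `remove_record`: matches are collected first and removed in a second pass, since mutating `self._rewrite_records` while iterating it would be unsafe, and a key is dropped entirely once its record list empties. The registry contents below are invented for illustration.

from typing import Any, Callable, Optional

class RewriteRegistry:  # minimal stand-in for the real registry class
    def __init__(self):
        self._rewrite_records = {
            'resize': [dict(_object='fn_a'), dict(_object='fn_b')],
            'pad': [dict(_object='fn_a')],
        }

    def remove_record(self, object: Any, filter_cb: Optional[Callable] = None):
        key_to_pop = []
        for key, records in self._rewrite_records.items():
            for rec in records:
                if rec['_object'] == object:
                    if filter_cb is not None and filter_cb(rec):
                        continue  # filter_cb elected to keep this record
                    key_to_pop.append((key, rec))
        for key, rec in key_to_pop:
            records = self._rewrite_records[key]
            records.remove(rec)
            if len(records) == 0:
                self._rewrite_records.pop(key)

registry = RewriteRegistry()
registry.remove_record('fn_a')
assert 'pad' not in registry._rewrite_records         # emptied key is dropped
assert len(registry._rewrite_records['resize']) == 1  # only fn_b remains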
37 changes: 26 additions & 11 deletions mmdeploy/utils/test.py
@@ -17,6 +17,11 @@
 from mmdeploy.utils import (IR, Backend, get_backend, get_dynamic_axes,
                             get_ir_config, get_onnx_config)
 
+try:
+    from torch.testing import assert_close as torch_assert_close
+except Exception:
+    from torch.testing import assert_allclose as torch_assert_close
+
 
 def backend_checker(backend: Backend, require_plugin: bool = False):
     """A decorator which checks if a backend is available.
@@ -189,12 +194,6 @@ def __init__(self, recover_class):
         self._recover_class = recover_class
 
     def __enter__(self):
-        return self
-
-    def __exit__(self, type, value, trace):
-        self.recover()
-
-    def set(self, **kwargs):
         """Replace attributes in backend wrappers with dummy items."""
         obj = self._recover_class
         self.init = obj.__init__
@@ -203,10 +202,9 @@ def set(self, **kwargs):
         obj.__init__ = SwitchBackendWrapper.BackendWrapper.__init__
         obj.forward = SwitchBackendWrapper.BackendWrapper.forward
         obj.__call__ = SwitchBackendWrapper.BackendWrapper.__call__
-        for k, v in kwargs.items():
-            setattr(obj, k, v)
-        return self
 
-    def recover(self):
+    def __exit__(self, type, value, trace):
         """Recover to original class."""
         assert self.init is not None and \
             self.forward is not None,\
@@ -216,6 +214,11 @@ def recover(self):
         obj.forward = self.forward
         obj.__call__ = self.call
 
+    def set(self, **kwargs):
+        obj = self._recover_class
+        for k, v in kwargs.items():
+            setattr(obj, k, v)
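With set-up moved into `__enter__` (which no longer returns `self`), the wrapper is constructed before the `with` block and `set` is called inside it. A self-contained sketch of the same pattern, using an invented `DummyWrapper`/`Patcher` pair rather than the real backend classes:

class DummyWrapper:  # hypothetical stand-in for a backend wrapper class
    def forward(self, x):
        return x

class Patcher:  # mirrors the SwitchBackendWrapper flow, much simplified
    def __init__(self, cls):
        self._cls = cls

    def __enter__(self):
        self._old_forward = self._cls.forward
        self._cls.forward = lambda self, x: 'patched'

    def __exit__(self, type, value, trace):
        self._cls.forward = self._old_forward

    def set(self, **kwargs):
        for k, v in kwargs.items():
            setattr(self._cls, k, v)

patcher = Patcher(DummyWrapper)  # built outside: __enter__ returns None
with patcher:
    patcher.set(output_names=['out'])
    assert DummyWrapper().forward(None) == 'patched'
assert DummyWrapper().forward(3) == 3  # original behavior restored on exit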


 def assert_allclose(expected: List[Union[torch.Tensor, np.ndarray]],
                     actual: List[Union[torch.Tensor, np.ndarray]],
@@ -239,8 +242,7 @@ def assert_allclose(expected: List[Union[torch.Tensor, np.ndarray]],
         if isinstance(actual[i], (list, np.ndarray)):
             actual[i] = torch.tensor(actual[i])
         try:
-            torch.testing.assert_allclose(
-                actual[i], expected[i], rtol=1e-03, atol=1e-05)
+            torch_assert_close(actual[i], expected[i], rtol=1e-03, atol=1e-05)
         except AssertionError as error:
             if tolerate_small_mismatch:
                 assert '(0.00%)' in str(error), str(error)
@@ -417,6 +419,19 @@ def get_backend_outputs(ir_file_path: str,
     if backend == Backend.TENSORRT:
         device = 'cuda'
         model_inputs = dict((k, v.cuda()) for k, v in model_inputs.items())
+        input_shapes = dict(
+            (k, dict(min_shape=v.shape, max_shape=v.shape, opt_shape=v.shape))
+            for k, v in model_inputs.items())
+        model_inputs_cfg = deploy_cfg['backend_config'].get(
+            'model_inputs', [dict(input_shapes=input_shapes)])
+        if len(model_inputs_cfg) < 1:
+            model_inputs_cfg = [dict(input_shapes=input_shapes)]
+
+        if 'input_shapes' not in model_inputs_cfg[0]:
+            model_inputs_cfg[0]['input_shapes'] = input_shapes
+
+        deploy_cfg['backend_config']['model_inputs'] = model_inputs_cfg
+
     elif backend == Backend.OPENVINO:
         input_info = {
             name: value.shape
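The TensorRT branch above fills in `input_shapes` when the deploy config omits them; because min/opt/max are all set to the tensor's actual shape, the resulting engine is effectively static-shape. A sketch of the structure this generates (the tensor name and shape are illustrative):

import torch

model_inputs = {'input': torch.rand(1, 3, 224, 224)}
input_shapes = dict(
    (k, dict(min_shape=v.shape, max_shape=v.shape, opt_shape=v.shape))
    for k, v in model_inputs.items())
# {'input': {'min_shape': torch.Size([1, 3, 224, 224]), ...}}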
33 changes: 19 additions & 14 deletions tests/test_apis/test_calibration.py
@@ -4,14 +4,16 @@
 from multiprocessing import Process
 
 import mmcv
+import pytest
 
 from mmdeploy.apis import create_calib_input_data
 
 calib_file = tempfile.NamedTemporaryFile(suffix='.h5').name
 ann_file = 'tests/data/annotation.json'
 
 
-def get_end2end_deploy_cfg():
+@pytest.fixture
+def deploy_cfg():
     deploy_cfg = mmcv.Config(
         dict(
             onnx_config=dict(
@@ -53,14 +55,15 @@ def get_end2end_deploy_cfg():
     return deploy_cfg
 
 
-def get_partition_deploy_cfg():
-    deploy_cfg = get_end2end_deploy_cfg()
+@pytest.fixture
+def partition_deploy_cfg(deploy_cfg):
     deploy_cfg._cfg_dict['partition_config'] = dict(
         type='two_stage', apply_marks=True)
     return deploy_cfg
 
 
-def get_model_cfg():
+@pytest.fixture
+def model_cfg():
     dataset_type = 'CustomDataset'
     data_root = 'tests/data/'
     img_norm_cfg = dict(
@@ -169,10 +172,8 @@ def get_model_cfg():
     return model_cfg
 
 
-def run_test_create_calib_end2end():
+def run_test_create_calib_end2end(deploy_cfg, model_cfg):
     import h5py
-    model_cfg = get_model_cfg()
-    deploy_cfg = get_end2end_deploy_cfg()
     create_calib_input_data(
         calib_file,
         deploy_cfg,
@@ -194,18 +195,19 @@ def run_test_create_calib_end2end():
     # new process.
 
 
-def test_create_calib_end2end():
-    p = Process(target=run_test_create_calib_end2end)
+def test_create_calib_end2end(deploy_cfg, model_cfg):
+    p = Process(
+        target=run_test_create_calib_end2end,
+        kwargs=dict(deploy_cfg=deploy_cfg, model_cfg=model_cfg))
     try:
         p.start()
     finally:
         p.join()
 
 
-def run_test_create_calib_parittion():
+def run_test_create_calib_parittion(partition_deploy_cfg, model_cfg):
     import h5py
-    model_cfg = get_model_cfg()
-    deploy_cfg = get_partition_deploy_cfg()
+    deploy_cfg = partition_deploy_cfg
     create_calib_input_data(
         calib_file,
         deploy_cfg,
@@ -227,8 +229,11 @@ def run_test_create_calib_parittion():
         assert calib_data[partition_name][input_names[i]]['0'] is not None
 
 
-def test_create_calib_parittion():
-    p = Process(target=run_test_create_calib_parittion)
+def test_create_calib_parittion(partition_deploy_cfg, model_cfg):
+    p = Process(
+        target=run_test_create_calib_parittion,
+        kwargs=dict(
+            partition_deploy_cfg=partition_deploy_cfg, model_cfg=model_cfg))
     try:
         p.start()
     finally:
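The calibration helpers become pytest fixtures above, so the configs are injected by parameter name and `partition_deploy_cfg` can build on `deploy_cfg` through fixture chaining. A minimal sketch of that pattern (the config contents here are invented):

import pytest

@pytest.fixture
def deploy_cfg():
    return {'onnx_config': dict()}

@pytest.fixture
def partition_deploy_cfg(deploy_cfg):  # requests the fixture above by name
    deploy_cfg['partition_config'] = dict(type='two_stage', apply_marks=True)
    return deploy_cfg

def test_partition(partition_deploy_cfg):
    assert 'partition_config' in partition_deploy_cfg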
1 change: 1 addition & 0 deletions tests/test_backend/test_wrapper.py
@@ -180,6 +180,7 @@ def run_wrapper(backend, wrapper, input):
 ALL_BACKEND = list(Backend)
 ALL_BACKEND.remove(Backend.DEFAULT)
 ALL_BACKEND.remove(Backend.PYTORCH)
+ALL_BACKEND.remove(Backend.SNPE)
 ALL_BACKEND.remove(Backend.SDK)


19 changes: 19 additions & 0 deletions tests/test_codebase/test_mmcls/conftest.py
@@ -0,0 +1,19 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import pytest
+
+from mmdeploy.codebase import import_codebase
+from mmdeploy.utils import Codebase
+
+
+def pytest_ignore_collect(*args, **kwargs):
+    import importlib
+    return importlib.util.find_spec('mmcls') is None
+
+
+@pytest.fixture(autouse=True, scope='package')
+def import_all_modules():
+    codebase = Codebase.MMCLS
+    try:
+        import_codebase(codebase)
+    except ImportError:
+        pytest.skip(f'{codebase} is not installed.', allow_module_level=True)
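When `pytest_ignore_collect` returns True, pytest skips collecting the whole `test_mmcls` directory; `find_spec` checks availability without actually importing the package. A quick standalone check of that idiom:

import importlib.util

print(importlib.util.find_spec('mmcls') is None)  # True when mmcls is absent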