Skip to content

Commit 847a906

Browse files
LKJackyliukai
andauthored
support mmrazor (#1701)
* support mmrazor * add make divisible * update * Pruning -> ModelCompress and add docstring --------- Co-authored-by: liukai <[email protected]>
1 parent 637958a commit 847a906

File tree

5 files changed

+150
-2
lines changed

5 files changed

+150
-2
lines changed

mmdeploy/apis/utils/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def build_task_processor(model_cfg: mmengine.Config,
4141
BaseTask: A task processor.
4242
"""
4343
check_backend_device(deploy_cfg=deploy_cfg, device=device)
44-
codebase_type = get_codebase(deploy_cfg)
44+
codebase_type = get_codebase(deploy_cfg, model_cfg=model_cfg)
4545
custom_module_list = get_codebase_external_module(deploy_cfg)
4646
import_codebase(codebase_type, custom_module_list)
4747
codebase = get_codebase_class(codebase_type)
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
from .mmrazor import MMCodebase, MMRazor
3+
4+
__all__ = ['MMRazor', 'MMCodebase']
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
import copy
3+
from typing import Dict, Optional, Tuple, Union
4+
5+
import numpy as np
6+
import torch
7+
from mmengine import Config
8+
from mmengine.model import BaseDataPreprocessor
9+
from mmengine.registry import Registry
10+
11+
from mmdeploy.apis.utils import build_task_processor
12+
from mmdeploy.codebase.base import CODEBASE, BaseTask, MMCodebase
13+
from mmdeploy.utils import Codebase, Task
14+
15+
MMRAZOR_TASK = Registry('mmrazor_tasks')
16+
17+
18+
@CODEBASE.register_module(Codebase.MMRAZOR.value)
19+
class MMRazor(MMCodebase):
20+
"""MMRazor codebase class."""
21+
task_registry = MMRAZOR_TASK
22+
23+
@classmethod
24+
def register_deploy_modules(cls):
25+
"""Register all rewriters for mmrazor."""
26+
pass
27+
28+
@classmethod
29+
def register_all_modules(cls):
30+
"""Register all related modules and rewriters for mmrazor."""
31+
from mmrazor.utils import register_all_modules
32+
register_all_modules(True)
33+
34+
@classmethod
35+
def build_task_processor(cls, model_cfg: Config, deploy_cfg: Config,
36+
device: str):
37+
"""Build task processor for mmrazor.
38+
39+
Now we use ModelCompress by default.
40+
"""
41+
return ModelCompress(
42+
model_cfg=model_cfg, deploy_cfg=deploy_cfg, device=device)
43+
44+
45+
@MMRAZOR_TASK.register_module(Task.ModelCompress.value)
46+
class ModelCompress(BaseTask):
47+
"""General model compress task for mmrazor.
48+
49+
Args:
50+
model_cfg (Config): Original PyTorch model config file
51+
deploy_cfg (Config): Deployment config file or loaded Config
52+
object.
53+
device (str): A string represents device type.
54+
experiment_name (str, optional): Name of current experiment.
55+
If not specified, timestamp will be used as
56+
``experiment_name``. Defaults to ``None``.
57+
"""
58+
59+
def __init__(self,
60+
model_cfg: Config,
61+
deploy_cfg: Config,
62+
device: str,
63+
experiment_name: str = 'BaseTask'):
64+
65+
super().__init__(model_cfg, deploy_cfg, device, experiment_name)
66+
self.origin_model_cfg = self.revert_model_cfg(model_cfg)
67+
self.base_task = build_task_processor(self.origin_model_cfg,
68+
deploy_cfg, device)
69+
70+
def revert_model_cfg(self, model_cfg: Config):
71+
"""Restore the original model config from the model config of the
72+
compressed model."""
73+
origin_model_cfg = copy.deepcopy(model_cfg)
74+
model = model_cfg['model']
75+
if 'architecture' in model:
76+
origin_model = model['architecture']
77+
elif 'algorithm' in model:
78+
origin_model = model['algorithm']['architecture']
79+
else:
80+
raise NotImplementedError()
81+
origin_model_cfg['model'] = origin_model
82+
if 'data_preprocessor' in origin_model:
83+
origin_model_cfg['data_preprocessor'] = origin_model[
84+
'data_preprocessor']
85+
return origin_model_cfg
86+
87+
# abstract method
88+
89+
def build_backend_model(self,
90+
model_files=None,
91+
data_preprocessor_updater=None,
92+
**kwargs) -> torch.nn.Module:
93+
"""Build backend model for using base task."""
94+
return self.base_task.build_backend_model(model_files,
95+
data_preprocessor_updater,
96+
**kwargs)
97+
98+
def create_input(self,
99+
imgs: Union[str, np.ndarray],
100+
input_shape=None,
101+
data_preprocessor: Optional[BaseDataPreprocessor] = None,
102+
**kwargs) -> Tuple[Dict, torch.Tensor]:
103+
"""Create input using base task."""
104+
return self.base_task.create_input(imgs, input_shape,
105+
data_preprocessor, **kwargs)
106+
107+
def get_model_name(self, *args, **kwargs) -> str:
108+
"""Get model name using base task."""
109+
return self.base_task.get_model_name(*args, **kwargs)
110+
111+
def get_preprocess(self, *args, **kwargs) -> Dict:
112+
"""Get data preprocess name using base task."""
113+
return self.base_task.get_preprocess(*args, **kwargs)
114+
115+
def get_postprocess(self, *args, **kwargs) -> Dict:
116+
"""Get data poseprocess name using base task."""
117+
return self.base_task.get_postprocess(*args, **kwargs)
118+
119+
@staticmethod
120+
def get_partition_cfg(partition_type: str, **kwargs) -> Dict:
121+
"""Get a certain partition config."""
122+
raise NotImplementedError()
123+
124+
def build_pytorch_model(self,
125+
model_checkpoint: Optional[str] = None,
126+
cfg_options: Optional[Dict] = None,
127+
**kwargs) -> torch.nn.Module:
128+
"""Build PyTorch model for mmrazor and execute post process for
129+
mmdeploy."""
130+
model = super().build_pytorch_model(model_checkpoint, cfg_options,
131+
**kwargs)
132+
if hasattr(model, 'post_process_for_mmdeploy'):
133+
model.post_process_for_mmdeploy()
134+
135+
return model

mmdeploy/utils/config_utils.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,8 @@ def register_codebase(codebase: str) -> Codebase:
8383
return Codebase.get(codebase)
8484

8585

86-
def get_codebase(deploy_cfg: Union[str, mmengine.Config]) -> Codebase:
86+
def get_codebase(deploy_cfg: Union[str, mmengine.Config],
87+
model_cfg=None) -> Codebase:
8788
"""Get the codebase from the config.
8889
8990
Args:
@@ -92,6 +93,12 @@ def get_codebase(deploy_cfg: Union[str, mmengine.Config]) -> Codebase:
9293
Returns:
9394
Codebase : An enumeration denotes the codebase type.
9495
"""
96+
if model_cfg is not None:
97+
# using mmrazor codebase if the model is a mmrazor model.
98+
model_cfg: dict = model_cfg['model']
99+
if model_cfg.get('_scope_', None) == 'mmrazor'\
100+
or model_cfg['type'].startswith('mmrazor.'):
101+
return register_codebase('mmrazor')
95102
codebase_config = get_codebase_config(deploy_cfg)
96103
assert 'type' in codebase_config, 'The codebase config of deploy config'\
97104
'requires a "type" field'

mmdeploy/utils/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ class Task(AdvancedEnum):
2828
POSE_DETECTION = 'PoseDetection'
2929
ROTATED_DETECTION = 'RotatedDetection'
3030
VIDEO_RECOGNITION = 'VideoRecognition'
31+
ModelCompress = 'ModelCompress'
3132

3233

3334
class Codebase(AdvancedEnum):
@@ -41,6 +42,7 @@ class Codebase(AdvancedEnum):
4142
MMPOSE = 'mmpose'
4243
MMROTATE = 'mmrotate'
4344
MMACTION = 'mmaction'
45+
MMRAZOR = 'mmrazor'
4446

4547

4648
class IR(AdvancedEnum):

0 commit comments

Comments
 (0)