36 changes: 25 additions & 11 deletions mmdeploy/codebase/mmdet/models/detectors/single_stage.py
@@ -12,7 +12,7 @@
 
 @mark(
     'detector_forward', inputs=['input'], outputs=['dets', 'labels', 'masks'])
-def __forward_impl(ctx, self, batch_inputs, data_samples, **kwargs):
+def __forward_impl(self, batch_inputs, data_samples):
     """Rewrite and adding mark for `forward`.
 
     Encapsulate this function for rewriting `forward` of BaseDetector.
@@ -25,6 +25,27 @@ def __forward_impl(ctx, self, batch_inputs, data_samples, **kwargs):
     return output
 
 
+@torch.fx.wrap
+def _set_metainfo(data_samples, img_shape):
+    """Set the metainfo.
+
+    Code in this function cannot be traced by fx.
+    """
+
+    # fx can not trace deepcopy correctly
+    data_samples = copy.deepcopy(data_samples)
+    if data_samples is None:
+        data_samples = [DetDataSample()]
+
+    # note that we can not use `set_metainfo`, deepcopy would crash the
+    # onnx trace.
+    for data_sample in data_samples:
+        data_sample.set_field(
+            name='img_shape', value=img_shape, field_type='metainfo')
+
+    return data_samples
+
+
 @FUNCTION_REWRITER.register_rewriter(
     'mmdet.models.detectors.single_stage.SingleStageDetector.forward')
 def single_stage_detector__forward(self,
@@ -53,9 +74,7 @@ def single_stage_detector__forward(self,
                 (num_instances, ).
     """
     ctx = FUNCTION_REWRITER.get_context()
-    data_samples = copy.deepcopy(data_samples)
-    if data_samples is None:
-        data_samples = [DetDataSample()]
+
     deploy_cfg = ctx.cfg
 
     # get origin input shape as tensor to support onnx dynamic shape
@@ -65,11 +84,6 @@ def single_stage_detector__forward(self,
         img_shape = [int(val) for val in img_shape]
 
     # set the metainfo
-    # note that we can not use `set_metainfo`, deepcopy would crash the
-    # onnx trace.
-    for data_sample in data_samples:
-        data_sample.set_field(
-            name='img_shape', value=img_shape, field_type='metainfo')
+    data_samples = _set_metainfo(data_samples, img_shape)
 
-    return __forward_impl(
-        ctx, self, batch_inputs, data_samples=data_samples, **kwargs)
+    return __forward_impl(self, batch_inputs, data_samples=data_samples)
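
Aside (not part of the diff): the reason `_set_metainfo` is pulled out and decorated with `@torch.fx.wrap` is that `torch.fx.wrap`, registered at module level, tells symbolic tracing to keep the call as a single opaque `call_function` node instead of tracing into the body, so Python-only logic such as `copy.deepcopy` and the `is None` branch no longer breaks the trace. A minimal, self-contained sketch of that behaviour, with illustrative names (`fill_default`, `Toy`) that are not from this PR:

import copy

import torch
import torch.fx


@torch.fx.wrap
def fill_default(meta):
    # deepcopy and the `is None` branch would normally break symbolic tracing
    meta = copy.deepcopy(meta)
    if meta is None:
        meta = {'scale': 2.0}
    return meta


class Toy(torch.nn.Module):

    def forward(self, x, meta=None):
        meta = fill_default(meta)  # stays a single call_function node
        return x * meta['scale']


traced = torch.fx.symbolic_trace(Toy())
print(traced.graph)                  # fill_default shows up as a leaf call
print(traced(torch.ones(2), None))   # tensor([2., 2.])

This is also why the rewriter change in the second file has to re-register copied functions: the wrap bookkeeping is keyed on the module globals in which the wrapped function lives.
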
33 changes: 33 additions & 0 deletions mmdeploy/core/rewriters/function_rewriter.py
@@ -1,8 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import types
 from collections import defaultdict
 from typing import (Any, Callable, Dict, List, MutableSequence, Optional,
                     Tuple, Union)
 
+from torch.fx._symbolic_trace import _wrapped_fns_to_patch
+
 from mmdeploy.utils import IR, Backend, get_root_logger
 from .rewriter_utils import (Checker, ContextCaller, RewriterRegistry,
                              copy_function, get_frame_func, get_func_qualname,
@@ -94,6 +97,24 @@ def _del_func(path: str):
             continue
 
 
+def _fx_wrap_copied_fn(func: types.FunctionType,
+                       copied_func: types.FunctionType):
+    """If a function is wrapped by torch.fx.wrap, its copy also needs to be
+    wrapped by torch.fx.wrap."""
+    if not hasattr(func, '__globals__'):
+        return
+
+    wrapped_fns_globals = [item[0] for item in _wrapped_fns_to_patch]
+    wrapped_fns_names = [item[1] for item in _wrapped_fns_to_patch]
+
+    # check if wrapped by torch.fx.wrap
+    if func.__globals__ in wrapped_fns_globals:
+        idx = wrapped_fns_globals.index(func.__globals__)
+        fn_name = wrapped_fns_names[idx]
+        # a hacky way to wrap the func in copied func
+        _wrapped_fns_to_patch.append((copied_func.__globals__, fn_name))
+
+
 class FunctionRewriter:
     """A function rewriter which maintains rewritten functions.
 
@@ -147,6 +168,8 @@ def enter(self, cfg: Dict = dict(), env: Dict = dict(), **kwargs):
         self._func_contexts.clear()
         # Get current records
         functions_records = self._registry.get_records(env)
+        # Get current fx wrapped func nums
+        self._ori_fx_wrap_num = len(_wrapped_fns_to_patch)
 
         self._origin_functions = list()
         self._additional_functions = list()
@@ -186,11 +209,16 @@ def enter(self, cfg: Dict = dict(), env: Dict = dict(), **kwargs):
 
             # Create context_caller
             rewrite_function = record_dict['_object']
+            # The func before and after copy has different globals
             rewrite_function = copy_function(rewrite_function)
             extra_kwargs = kwargs.copy()
             extra_kwargs.update(record_dict)
             context_caller = ContextCaller(rewrite_function, origin_func,
                                            cfg, **extra_kwargs)
+            # If there is a function wrapped by torch.fx.wrap in
+            # rewrite_function's globals, we need to wrap the same name
+            # function in copied function's globals.
+            _fx_wrap_copied_fn(record_dict['_object'], context_caller.func)
 
             qualname = get_func_qualname(rewrite_function)
             self._func_contexts[qualname].append(context_caller)
@@ -209,6 +237,11 @@ def enter(self, cfg: Dict = dict(), env: Dict = dict(), **kwargs):
 
     def exit(self):
        """Recover the function rewrite."""
+        # Restore _wrapped_fns_to_patch
+        cur_fx_wrap_num = len(_wrapped_fns_to_patch)
+        for _ in range(cur_fx_wrap_num - self._ori_fx_wrap_num):
+            _wrapped_fns_to_patch.pop(-1)
+
         for func_dict in self._origin_functions:
             func_path = func_dict['func_path']
             func = func_dict['origin_func']
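
Aside (not part of the diff): at the PyTorch version this PR targets, `_wrapped_fns_to_patch` is a plain list of `(globals_dict, function_name)` pairs (newer PyTorch releases store this internal differently), and the tracer only patches names whose module globals appear in it. Because `copy_function` gives the rewritten function a fresh `__globals__` dict, the copy would otherwise call an unpatched `_set_metainfo` during tracing; hence `_fx_wrap_copied_fn` and the enter/exit length bookkeeping above. A rough sketch of the idea, where `copy_with_new_globals` is a hypothetical stand-in for mmdeploy's `copy_function`:

import types

from torch.fx._symbolic_trace import _wrapped_fns_to_patch


def copy_with_new_globals(func: types.FunctionType) -> types.FunctionType:
    # same code object, but a fresh __globals__ dict (as copy_function does)
    new_globals = dict(func.__globals__)
    copied = types.FunctionType(func.__code__, new_globals, func.__name__,
                                func.__defaults__, func.__closure__)
    new_globals[func.__name__] = copied
    return copied


def register_copy_for_fx(func, copied_func):
    # mirrors _fx_wrap_copied_fn: if func's globals are registered for fx
    # patching, register the copy's globals under the same name as well
    wrapped_globals = [item[0] for item in _wrapped_fns_to_patch]
    wrapped_names = [item[1] for item in _wrapped_fns_to_patch]
    if func.__globals__ in wrapped_globals:
        name = wrapped_names[wrapped_globals.index(func.__globals__)]
        _wrapped_fns_to_patch.append((copied_func.__globals__, name))


# enter()/exit()-style cleanup: remember the length on enter, pop back to it
# on exit so the global list is left exactly as it was found.
num_before = len(_wrapped_fns_to_patch)
# ... register_copy_for_fx(...) calls would happen here while rewrites are live
while len(_wrapped_fns_to_patch) > num_before:
    _wrapped_fns_to_patch.pop(-1)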