Skip to content

Commit 8e2f655

Browse files
authored
rewrite torch.cat for TensorRT when input is dynamic (#1851)
1 parent 847a906 commit 8e2f655

File tree

3 files changed

+57
-2
lines changed

3 files changed

+57
-2
lines changed

mmdeploy/pytorch/functions/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from . import adaptive_pool # noqa: F401,F403
33
from . import any # noqa: F401,F403
44
from . import atan2 # noqa: F401,F403
5+
from . import cat # noqa: F401,F403
56
from . import chunk # noqa: F401,F403
67
from . import clip # noqa: F401,F403
78
from . import expand # noqa: F401,F403

mmdeploy/pytorch/functions/cat.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
from typing import Sequence
3+
4+
import torch
5+
from torch import Tensor
6+
7+
from mmdeploy.core import FUNCTION_REWRITER
8+
from mmdeploy.utils import get_dynamic_axes
9+
10+
11+
@FUNCTION_REWRITER.register_rewriter(func_name='torch.cat', backend='tensorrt')
12+
def cat__tensorrt(tensors: Sequence[Tensor], *args, **kwargs) -> torch.Tensor:
13+
"""Rewrite `cat` for TensorRT backend.
14+
15+
cat in TensorRT does not support bool or uint8 type when input is dynamic.
16+
"""
17+
ctx = FUNCTION_REWRITER.get_context()
18+
if get_dynamic_axes(ctx.cfg) is None:
19+
return ctx.origin_func(tensors, *args, **kwargs)
20+
if len(tensors) > 0 and (tensors[0].dtype in [torch.bool, torch.uint8]):
21+
original_dtype = tensors[0].dtype
22+
tensors = [i.to(torch.int32) for i in tensors]
23+
return ctx.origin_func(tensors, *args, **kwargs).to(original_dtype)
24+
return ctx.origin_func(tensors, *args, **kwargs)

tests/test_pytorch/test_pytorch_functions.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,13 @@
1919
codebase_config=dict(type='mmdet', task='ObjectDetection')))
2020

2121

22-
def get_trt_config(output_names, shape):
22+
def get_trt_config(output_names, shape, dynamic_axes=None):
2323
deploy_cfg_tensorrt = Config(
2424
dict(
25-
onnx_config=dict(input_shape=None, output_names=output_names),
25+
onnx_config=dict(
26+
input_shape=None,
27+
output_names=output_names,
28+
dynamic_axes=dynamic_axes),
2629
backend_config=dict(
2730
type='tensorrt',
2831
common_config=dict(
@@ -615,3 +618,30 @@ def linspace_caller(*arg, **kwargs):
615618

616619
assert np.allclose(
617620
model_output, rewrite_outputs, rtol=1e-03, atol=1e-05)
621+
622+
623+
@backend_checker(Backend.TENSORRT)
624+
@pytest.mark.parametrize('dtype', [torch.bool, torch.float32])
625+
@pytest.mark.parametrize('dynamic_axes',
626+
[None, dict(input=dict({
627+
0: 'dim0',
628+
1: 'dim1'
629+
}))])
630+
def test_cat__tensorrt(dtype, dynamic_axes):
631+
input = torch.rand(2, 4)
632+
model = WrapFunction(lambda input: torch.cat(
633+
[input.to(dtype), input.to(dtype)], -1))
634+
pytorch_output = model(input)
635+
rewrite_output, _ = get_rewrite_outputs(
636+
model,
637+
model_inputs={'input': input},
638+
deploy_cfg=get_trt_config(['output'],
639+
shape=[2, 4],
640+
dynamic_axes=dynamic_axes),
641+
run_with_backend=True)
642+
assert pytorch_output.dtype == rewrite_output[0].dtype
643+
assert torch.allclose(
644+
pytorch_output.cpu().float(),
645+
rewrite_output[0].cpu().float(),
646+
rtol=1e-3,
647+
atol=1e-5)

0 commit comments

Comments
 (0)