
Commit 0b6b9ee

sync #1493 to support TorchAllocator as TensorRT Gpu Allocator and fix DCNv2 tensorrt plugin error (#1519)
1 parent 3f261e6 commit 0b6b9ee

File tree

6 files changed: +125 -9 lines changed

6 files changed

+125
-9
lines changed

csrc/mmdeploy/backend_ops/tensorrt/modulated_deform_conv/trt_modulated_deform_conv.cpp

Lines changed: 43 additions & 4 deletions
```diff
@@ -49,16 +49,55 @@ nvinfer1::IPluginV2DynamicExt *ModulatedDeformableConvPluginDynamic::clone() con
   return plugin;
 }
 
+static const nvinfer1::IDimensionExpr *get_hw(const nvinfer1::IDimensionExpr *input,
+                                              const nvinfer1::IDimensionExpr *weight,
+                                              const nvinfer1::IDimensionExpr *stride,
+                                              const nvinfer1::IDimensionExpr *pad,
+                                              const nvinfer1::IDimensionExpr *dilation,
+                                              nvinfer1::IExprBuilder &exprBuilder) {
+  using DimOp = nvinfer1::DimensionOperation;
+  auto expr_1 = exprBuilder.constant(1);
+
+  // effective kernel size: dilation * (weight - 1) + 1
+  auto kernel_0 = exprBuilder.operation(DimOp::kSUB, *weight, *expr_1);
+  auto kernel_1 = exprBuilder.operation(DimOp::kPROD, *dilation, *kernel_0);
+  auto kernel = exprBuilder.operation(DimOp::kSUM, *kernel_1, *expr_1);
+
+  // output size: (input + 2 * pad - kernel) // stride + 1
+  auto out_0 = exprBuilder.operation(DimOp::kSUM, *pad, *pad);
+  auto out_1 = exprBuilder.operation(DimOp::kSUM, *input, *out_0);
+  auto out_2 = exprBuilder.operation(DimOp::kSUB, *out_1, *kernel);
+  auto out_3 = exprBuilder.operation(DimOp::kFLOOR_DIV, *out_2, *stride);
+  auto out = exprBuilder.operation(DimOp::kSUM, *out_3, *expr_1);
+
+  return out;
+}
+
 nvinfer1::DimsExprs ModulatedDeformableConvPluginDynamic::getOutputDimensions(
     int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
     nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT {
+  using DimOp = nvinfer1::DimensionOperation;
+  auto weight_dim = inputs[3].d;
   nvinfer1::DimsExprs ret;
   ret.nbDims = 4;
   ret.d[0] = inputs[0].d[0];
   ret.d[1] = inputs[3].d[0];
 
-  ret.d[2] = inputs[1].d[2];
-  ret.d[3] = inputs[1].d[3];
+  auto input_h = inputs[0].d[2];
+  auto input_w = inputs[0].d[3];
+  auto weight_h = weight_dim[2];
+  auto weight_w = weight_dim[3];
+  auto dilation_w = exprBuilder.constant(mDilation.d[0]);
+  auto dilation_h = exprBuilder.constant(mDilation.d[1]);
+  auto pad_w = exprBuilder.constant(mPadding.d[0]);
+  auto pad_h = exprBuilder.constant(mPadding.d[1]);
+  auto stride_w = exprBuilder.constant(mStride.d[0]);
+  auto stride_h = exprBuilder.constant(mStride.d[1]);
+  auto expr_1 = exprBuilder.constant(1);
+  auto expr_2 = exprBuilder.constant(2);
+
+  ret.d[2] = get_hw(input_h, weight_h, stride_h, pad_h, dilation_h, exprBuilder);
+  ret.d[3] = get_hw(input_w, weight_w, stride_w, pad_w, dilation_w, exprBuilder);
 
   return ret;
 }
@@ -224,11 +263,11 @@ nvinfer1::IPluginV2 *ModulatedDeformableConvPluginDynamicCreator::createPlugin(
     }
     std::string field_name(fc->fields[i].name);
 
-    if (field_name.compare("deformable_group") == 0) {
+    if (field_name.compare("deform_groups") == 0) {
      deformableGroup = static_cast<const int *>(fc->fields[i].data)[0];
     }
 
-    if (field_name.compare("group") == 0) {
+    if (field_name.compare("groups") == 0) {
      group = static_cast<const int *>(fc->fields[i].data)[0];
     }
 
```
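The reworked `getOutputDimensions` computes the real output height and width instead of reusing the offset tensor's spatial dims, which was the DCNv2 plugin error for stride > 1. A plain-Python sketch of the arithmetic the `IDimensionExpr` graph in `get_hw` encodes (the helper name `conv_out_size` is ours, not part of the commit):

```python
def conv_out_size(input_size: int, kernel: int, stride: int, pad: int,
                  dilation: int) -> int:
    """One spatial axis of a (deformable) conv output, as built by get_hw.

    effective kernel: dilation * (kernel - 1) + 1
    output:           (input + 2 * pad - effective_kernel) // stride + 1
    """
    effective_kernel = dilation * (kernel - 1) + 1
    return (input_size + 2 * pad - effective_kernel) // stride + 1


# e.g. a 64x64 input, 3x3 kernel, stride 2, pad 1, dilation 1 -> 32x32
assert conv_out_size(64, kernel=3, stride=2, pad=1, dilation=1) == 32
```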

csrc/mmdeploy/backend_ops/tensorrt/modulated_deform_conv/trt_modulated_deform_conv_kernel.cu

Lines changed: 2 additions & 2 deletions
```diff
@@ -85,8 +85,8 @@ void ModulatedDeformConvForwardCUDAKernelLauncher(
   scalar_t* columns = (scalar_t*)workspace;
 
   const size_t input_step = channels * height * width;
-  const size_t offset_step = deformable_group * kernel_h * kernel_w * 2 * height * width;
-  const size_t mask_step = deformable_group * kernel_h * kernel_w * height * width;
+  const size_t offset_step = deformable_group * kernel_h * kernel_w * 2 * height_out * width_out;
+  const size_t mask_step = deformable_group * kernel_h * kernel_w * height_out * width_out;
   const size_t out_step = channels_out * height_out * width_out;
   const size_t out_group_step = out_step / group;
   const size_t col_g_step = channels * kernel_w * kernel_h / group * height_out * width_out;
```
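This step fix matters because DCNv2 predicts offsets and masks per *output* location: when stride > 1, their spatial dims are `height_out x width_out`, not the input's `height x width`, so the old per-sample strides walked the batch with the wrong pitch. A minimal shape sketch (tensor names and sizes are illustrative, not from the kernel):

```python
import torch

n, deform_group, kh, kw = 2, 1, 3, 3
h_out, w_out = 32, 32  # output dims, e.g. stride 2 on a 64x64 input

# DCNv2 layouts: one (dy, dx) pair and one mask value per kernel tap and
# output pixel, so per-sample batch strides must use h_out/w_out.
offset = torch.zeros(n, deform_group * 2 * kh * kw, h_out, w_out)
mask = torch.zeros(n, deform_group * kh * kw, h_out, w_out)

offset_step = deform_group * kh * kw * 2 * h_out * w_out
mask_step = deform_group * kh * kw * h_out * w_out
assert offset[0].numel() == offset_step and mask[0].numel() == mask_step
```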

mmdeploy/backend/tensorrt/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -33,7 +33,9 @@ def is_custom_ops_available():
 
 try:
     # import wrapper if pytorch is available
+    from .torch_allocator import TorchAllocator
     from .wrapper import TRTWrapper
     __all__ += ['TRTWrapper']
+    __all__ += ['TorchAllocator', 'TRTWrapper']
 except Exception:
     pass
```
mmdeploy/backend/tensorrt/torch_allocator.py (new file, inferred from the `from .torch_allocator import TorchAllocator` import above)

Lines changed: 63 additions & 0 deletions
```diff
@@ -0,0 +1,63 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import tensorrt as trt
+import torch
+
+from mmdeploy.utils import get_root_logger
+
+
+class TorchAllocator(trt.IGpuAllocator):
+    """PyTorch Cuda Allocator Wrapper."""
+
+    def __init__(self, device_id: int = 0) -> None:
+        super().__init__()
+
+        self.device_id = device_id
+        self.mems = set()
+
+    def __del__(self):
+        """destructor."""
+        mems = self.mems.copy()
+        (self.deallocate(mem) for mem in mems)
+
+    def allocate(self: trt.IGpuAllocator, size: int, alignment: int,
+                 flags: int) -> int:
+        """allocate gpu memory.
+
+        Args:
+            self (trt.IGpuAllocator): gpu allocator
+            size (int): memory size.
+            alignment (int): memory alignment.
+            flags (int): flags.
+
+        Returns:
+            int: memory address.
+        """
+        torch_stream = torch.cuda.current_stream(self.device_id)
+        logger = get_root_logger()
+        logger.debug(f'allocate {size} memory with TorchAllocator.')
+        assert alignment >= 0
+        if alignment > 0:
+            size = size | (alignment - 1) + 1
+        mem = torch.cuda.caching_allocator_alloc(
+            size, device=self.device_id, stream=torch_stream)
+        self.mems.add(mem)
+        return mem
+
+    def deallocate(self: trt.IGpuAllocator, memory: int) -> bool:
+        """deallocate memory.
+
+        Args:
+            self (trt.IGpuAllocator): gpu allocator
+            memory (int): memory address.
+
+        Returns:
+            bool: deallocate success.
+        """
+        logger = get_root_logger()
+        logger.debug(f'deallocate {memory} with TorchAllocator.')
+        if memory not in self.mems:
+            return False
+
+        torch.cuda.caching_allocator_delete(memory)
+        self.mems.discard(memory)
+        return True
```
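A minimal usage sketch for the new allocator: routing TensorRT's allocations through `torch.cuda`'s caching allocator lets TensorRT and PyTorch draw from one memory pool instead of competing for device memory. The engine path below is illustrative, not from the commit:

```python
from mmdeploy.backend.tensorrt import TorchAllocator
from mmdeploy.backend.tensorrt.utils import load

# Hand the runtime a TorchAllocator so engine deserialization allocates
# through PyTorch's caching allocator.
allocator = TorchAllocator(device_id=0)
engine = load('end2end.engine', allocator=allocator)  # illustrative path
```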

mmdeploy/backend/tensorrt/utils.py

Lines changed: 8 additions & 2 deletions
```diff
@@ -3,7 +3,7 @@
 import os
 import re
 import sys
-from typing import Dict, Optional, Sequence, Union
+from typing import Any, Dict, Optional, Sequence, Union
 
 import onnx
 import tensorrt as trt
@@ -24,17 +24,20 @@ def save(engine: trt.ICudaEngine, path: str) -> None:
         f.write(bytearray(engine.serialize()))
 
 
-def load(path: str) -> trt.ICudaEngine:
+def load(path: str, allocator: Optional[Any] = None) -> trt.ICudaEngine:
     """Deserialize TensorRT engine from disk.
 
     Args:
         path (str): The disk path to read the engine.
+        allocator (Any): gpu allocator
 
     Returns:
         tensorrt.ICudaEngine: The TensorRT engine loaded from disk.
     """
     load_tensorrt_plugin()
     with trt.Logger() as logger, trt.Runtime(logger) as runtime:
+        if allocator is not None:
+            runtime.gpu_allocator = allocator
         with open(path, mode='rb') as f:
             engine_bytes = f.read()
             trt.init_libnvinfer_plugins(logger, namespace='')
@@ -148,6 +151,9 @@ def from_onnx(onnx_model: Union[str, onnx.ModelProto],
     # create builder and network
     logger = trt.Logger(log_level)
     builder = trt.Builder(logger)
+
+    # TODO: use TorchAllocator as builder.gpu_allocator
+
     EXPLICIT_BATCH = 1 << (int)(
         trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
     network = builder.create_network(EXPLICIT_BATCH)
```
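The TODO above points at the builder-side counterpart of the runtime hookup. A sketch of what it could look like, assuming `trt.Builder.gpu_allocator` accepts an `IGpuAllocator` the same way `trt.Runtime.gpu_allocator` does (this code is not part of the commit):

```python
import tensorrt as trt

from mmdeploy.backend.tensorrt import TorchAllocator

# Hypothetical builder-side hookup mirroring runtime.gpu_allocator above,
# so engine building would also draw from PyTorch's caching allocator.
trt_logger = trt.Logger(trt.Logger.ERROR)
builder = trt.Builder(trt_logger)
builder.gpu_allocator = TorchAllocator(device_id=0)
```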

mmdeploy/backend/tensorrt/wrapper.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -8,6 +8,7 @@
 from mmdeploy.utils.timer import TimeCounter
 from ..base import BACKEND_WRAPPER, BaseWrapper
 from .init_plugins import load_tensorrt_plugin
+from .torch_allocator import TorchAllocator
 from .utils import load
 
 
@@ -76,10 +77,12 @@ class TRTWrapper(BaseWrapper):
 
     def __init__(self,
                  engine: Union[str, trt.ICudaEngine],
-                 output_names: Optional[Sequence[str]] = None):
+                 output_names: Optional[Sequence[str]] = None,
+                 device_id: int = 0):
         super().__init__(output_names)
         load_tensorrt_plugin()
         self.engine = engine
+        self.allocator = TorchAllocator(device_id)
         if isinstance(self.engine, str):
             self.engine = load(engine)
 
@@ -90,6 +93,9 @@ def __init__(self,
         self._register_state_dict_hook(TRTWrapper.__on_state_dict)
         self.context = self.engine.create_execution_context()
 
+        if hasattr(self.context, 'temporary_allocator'):
+            self.context.temporary_allocator = self.allocator
+
         self.__load_io_names()
 
     def __load_io_names(self):
```
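With the new `device_id` argument, the wrapper builds its own `TorchAllocator` and, where the installed TensorRT version exposes `IExecutionContext.temporary_allocator` (hence the `hasattr` guard), routes the context's temporary buffers through it. A usage sketch; the engine path and I/O names are illustrative and depend on how the engine was exported:

```python
import torch

from mmdeploy.backend.tensorrt import TRTWrapper

# 'end2end.engine', 'input' and 'output' are illustrative placeholders.
model = TRTWrapper('end2end.engine', output_names=['output'], device_id=0)
outputs = model({'input': torch.randn(1, 3, 224, 224).cuda()})
```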
