Skip to content

Commit 50bd7e9

Browse files
authored
support multi_scale_deform_attn trt plugin (#1844)
* support multi_scale_deform_attn trt plugin * fix lint error * add onnx symblic fun * init msdeformablecrossattn symblic fun & fix plugin * unittest for ms_deformable_cross_attn * fix template * update unittest * fix input contiguous of trtwrapper * update doc description
1 parent 8e2f655 commit 50bd7e9

File tree

10 files changed

+715
-0
lines changed

10 files changed

+715
-0
lines changed
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
// Copyright (c) OpenMMLab. All rights reserved
2+
#include "trt_ms_deform_attn.hpp"
3+
4+
#include <assert.h>
5+
6+
#include <chrono>
7+
8+
#include "trt_ms_deform_attn_kernel.hpp"
9+
#include "trt_serialize.hpp"
10+
11+
using namespace nvinfer1;
12+
13+
namespace mmdeploy {
14+
namespace {
15+
static const char *PLUGIN_VERSION{"1"};
16+
static const char *PLUGIN_NAME{"MMCVMultiScaleDeformableAttention"};
17+
} // namespace
18+
19+
MultiScaleDeformableAttnPluginDynamic::MultiScaleDeformableAttnPluginDynamic(const std::string &name)
20+
: TRTPluginBase(name) {}
21+
22+
MultiScaleDeformableAttnPluginDynamic::MultiScaleDeformableAttnPluginDynamic(const std::string name,
23+
const void *data,
24+
size_t length)
25+
: TRTPluginBase(name) {}
26+
MultiScaleDeformableAttnPluginDynamic::~MultiScaleDeformableAttnPluginDynamic() {}
27+
28+
nvinfer1::IPluginV2DynamicExt *MultiScaleDeformableAttnPluginDynamic::clone() const TRT_NOEXCEPT {
29+
MultiScaleDeformableAttnPluginDynamic *plugin = new MultiScaleDeformableAttnPluginDynamic(mLayerName);
30+
plugin->setPluginNamespace(getPluginNamespace());
31+
32+
return plugin;
33+
}
34+
35+
nvinfer1::DimsExprs MultiScaleDeformableAttnPluginDynamic::getOutputDimensions(
36+
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
37+
nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT {
38+
nvinfer1::DimsExprs ret;
39+
ret.nbDims = 3;
40+
ret.d[0] = inputs[0].d[0];
41+
ret.d[1] = inputs[3].d[1];
42+
43+
ret.d[2] = exprBuilder.operation(DimensionOperation::kPROD,
44+
*inputs[0].d[2], *inputs[0].d[3]);
45+
46+
return ret;
47+
}
48+
49+
bool MultiScaleDeformableAttnPluginDynamic::supportsFormatCombination(
50+
int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs, int nbOutputs) TRT_NOEXCEPT {
51+
52+
if (ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR)
53+
{
54+
if ((pos == 1) || (pos == 2))
55+
{
56+
return (ioDesc[pos].type == nvinfer1::DataType::kINT32);
57+
}
58+
else
59+
{
60+
return ((ioDesc[pos].type == ioDesc[0].type) &&
61+
((ioDesc[pos].type == nvinfer1::DataType::kFLOAT) || (ioDesc[pos].type == nvinfer1::DataType::kHALF)));
62+
}
63+
}
64+
else
65+
{
66+
return false;
67+
}
68+
}
69+
70+
void MultiScaleDeformableAttnPluginDynamic::configurePlugin(
71+
const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs,
72+
const nvinfer1::DynamicPluginTensorDesc *outputs, int nbOutputs) TRT_NOEXCEPT {
73+
}
74+
75+
size_t MultiScaleDeformableAttnPluginDynamic::getWorkspaceSize(
76+
const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
77+
const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const TRT_NOEXCEPT {
78+
return 0;
79+
}
80+
81+
int MultiScaleDeformableAttnPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
82+
const nvinfer1::PluginTensorDesc *outputDesc,
83+
const void *const *inputs, void *const *outputs,
84+
void *workSpace,
85+
cudaStream_t stream) TRT_NOEXCEPT {
86+
int32_t const batch = inputDesc[0].dims.d[0];
87+
int32_t spatial_size = inputDesc[0].dims.d[1];
88+
int32_t num_heads = inputDesc[0].dims.d[2];
89+
int32_t channels = inputDesc[0].dims.d[3];
90+
int32_t num_levels = inputDesc[1].dims.d[0];
91+
int32_t num_query = inputDesc[3].dims.d[1];
92+
int32_t num_point = inputDesc[3].dims.d[4];
93+
int32_t rc = 0;
94+
if (inputDesc[0].type == nvinfer1::DataType::kFLOAT)
95+
{
96+
float const* value = static_cast<float const*>(inputs[0]);
97+
int32_t const* spatialShapes = static_cast<int32_t const*>(inputs[1]);
98+
int32_t const* levelStartIndex = static_cast<int32_t const*>(inputs[2]);
99+
float const* samplingLoc = static_cast<float const*>(inputs[3]);
100+
float const* attnWeight = static_cast<float const*>(inputs[4]);
101+
float* output = static_cast<float*>(outputs[0]);
102+
103+
rc = ms_deform_attn_cuda_forward(value, spatialShapes, levelStartIndex, samplingLoc, attnWeight, output,
104+
batch, spatial_size, num_heads, channels, num_levels, num_query, num_point, stream);
105+
}
106+
else if (inputDesc[0].type == nvinfer1::DataType::kHALF)
107+
{
108+
const __half* value = static_cast<const __half*>(inputs[0]);
109+
int32_t const* spatialShapes = static_cast<int32_t const*>(inputs[1]);
110+
int32_t const* levelStartIndex = static_cast<int32_t const*>(inputs[2]);
111+
const __half* samplingLoc = static_cast<const __half*>(inputs[3]);
112+
const __half* attnWeight = static_cast<const __half*>(inputs[4]);
113+
__half* output = static_cast<__half*>(outputs[0]);
114+
115+
rc = ms_deform_attn_cuda_forward(value, spatialShapes, levelStartIndex, samplingLoc, attnWeight, output,
116+
batch, spatial_size, num_heads, channels, num_levels, num_query, num_point, stream);
117+
}
118+
119+
return rc;
120+
}
121+
122+
nvinfer1::DataType MultiScaleDeformableAttnPluginDynamic::getOutputDataType(
123+
int index, const nvinfer1::DataType *inputTypes, int nbInputs) const TRT_NOEXCEPT {
124+
return inputTypes[0];
125+
}
126+
127+
// IPluginV2 Methods
128+
const char *MultiScaleDeformableAttnPluginDynamic::getPluginType() const TRT_NOEXCEPT {
129+
return PLUGIN_NAME;
130+
}
131+
132+
const char *MultiScaleDeformableAttnPluginDynamic::getPluginVersion() const TRT_NOEXCEPT {
133+
return PLUGIN_VERSION;
134+
}
135+
136+
int MultiScaleDeformableAttnPluginDynamic::getNbOutputs() const TRT_NOEXCEPT { return 1; }
137+
138+
size_t MultiScaleDeformableAttnPluginDynamic::getSerializationSize() const TRT_NOEXCEPT {
139+
return 0;
140+
}
141+
142+
void MultiScaleDeformableAttnPluginDynamic::serialize(void *buffer) const TRT_NOEXCEPT {}
143+
144+
void MultiScaleDeformableAttnPluginDynamic::attachToContext(
145+
cudnnContext *cudnnContext, cublasContext *cublasContext,
146+
nvinfer1::IGpuAllocator *gpuAllocator) TRT_NOEXCEPT {}
147+
148+
void MultiScaleDeformableAttnPluginDynamic::detachFromContext() TRT_NOEXCEPT {}
149+
150+
////////////////////// creator /////////////////////////////
151+
152+
MultiScaleDeformableAttnPluginDynamicCreator::MultiScaleDeformableAttnPluginDynamicCreator() {
153+
mPluginAttributes.clear();
154+
mFC.nbFields = mPluginAttributes.size();
155+
mFC.fields = mPluginAttributes.data();
156+
}
157+
158+
const char *MultiScaleDeformableAttnPluginDynamicCreator::getPluginName() const TRT_NOEXCEPT {
159+
return PLUGIN_NAME;
160+
}
161+
162+
const char *MultiScaleDeformableAttnPluginDynamicCreator::getPluginVersion() const TRT_NOEXCEPT {
163+
return PLUGIN_VERSION;
164+
}
165+
166+
nvinfer1::IPluginV2 *MultiScaleDeformableAttnPluginDynamicCreator::createPlugin(
167+
const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT {
168+
169+
MultiScaleDeformableAttnPluginDynamic *plugin = new MultiScaleDeformableAttnPluginDynamic(name);
170+
plugin->setPluginNamespace(getPluginNamespace());
171+
return plugin;
172+
}
173+
174+
nvinfer1::IPluginV2 *MultiScaleDeformableAttnPluginDynamicCreator::deserializePlugin(
175+
const char *name, const void *serialData, size_t serialLength) TRT_NOEXCEPT {
176+
auto plugin = new MultiScaleDeformableAttnPluginDynamic(name, serialData, serialLength);
177+
plugin->setPluginNamespace(getPluginNamespace());
178+
return plugin;
179+
}
180+
REGISTER_TENSORRT_PLUGIN(MultiScaleDeformableAttnPluginDynamicCreator);
181+
} // namespace mmdeploy
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
// Copyright (c) OpenMMLab. All rights reserved.
2+
#ifndef TRT_MS_DEFORM_ATTN_HPP
3+
#define TRT_MS_DEFORM_ATTN_HPP
4+
#include <cublas_v2.h>
5+
6+
#include <memory>
7+
#include <string>
8+
#include <vector>
9+
10+
#include "trt_plugin_base.hpp"
11+
12+
namespace mmdeploy {
13+
class MultiScaleDeformableAttnPluginDynamic : public TRTPluginBase {
14+
public:
15+
16+
MultiScaleDeformableAttnPluginDynamic(const std::string &name);
17+
18+
MultiScaleDeformableAttnPluginDynamic(const std::string name, const void *data, size_t length);
19+
20+
MultiScaleDeformableAttnPluginDynamic();
21+
22+
~MultiScaleDeformableAttnPluginDynamic() TRT_NOEXCEPT override;
23+
24+
// IPluginV2DynamicExt Methods
25+
nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override;
26+
nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs,
27+
int nbInputs, nvinfer1::IExprBuilder &exprBuilder)
28+
TRT_NOEXCEPT override;
29+
bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs,
30+
int nbOutputs) TRT_NOEXCEPT override;
31+
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs,
32+
const nvinfer1::DynamicPluginTensorDesc *out,
33+
int nbOutputs) TRT_NOEXCEPT override;
34+
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
35+
const nvinfer1::PluginTensorDesc *outputs,
36+
int nbOutputs) const TRT_NOEXCEPT override;
37+
int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
38+
const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
39+
void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT override;
40+
void attachToContext(cudnnContext *cudnnContext, cublasContext *cublasContext,
41+
nvinfer1::IGpuAllocator *gpuAllocator) TRT_NOEXCEPT override;
42+
void detachFromContext() TRT_NOEXCEPT override;
43+
44+
// IPluginV2Ext Methods
45+
nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes,
46+
int nbInputs) const TRT_NOEXCEPT override;
47+
48+
// IPluginV2 Methods
49+
const char *getPluginType() const TRT_NOEXCEPT override;
50+
const char *getPluginVersion() const TRT_NOEXCEPT override;
51+
int getNbOutputs() const TRT_NOEXCEPT override;
52+
size_t getSerializationSize() const TRT_NOEXCEPT override;
53+
void serialize(void *buffer) const TRT_NOEXCEPT override;
54+
};
55+
56+
class MultiScaleDeformableAttnPluginDynamicCreator : public TRTPluginCreatorBase {
57+
public:
58+
MultiScaleDeformableAttnPluginDynamicCreator();
59+
60+
const char *getPluginName() const TRT_NOEXCEPT override;
61+
62+
const char *getPluginVersion() const TRT_NOEXCEPT override;
63+
64+
nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc)
65+
TRT_NOEXCEPT override;
66+
67+
nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData,
68+
size_t serialLength) TRT_NOEXCEPT override;
69+
};
70+
} // namespace mmdeploy
71+
#endif // TRT_MS_DEFORM_ATTN_HPP
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
// Copyright (c) OpenMMLab. All rights reserved
2+
#include <assert.h>
3+
#include <cuda_fp16.h>
4+
5+
#include "common_cuda_helper.hpp"
6+
#include "trt_ms_deform_attn_kernel.cuh"
7+
#include "trt_ms_deform_attn_kernel.hpp"
8+
#include "trt_plugin_helper.hpp"
9+
10+
template <typename scalar_t>
11+
void ms_deformable_im2col_cuda(cudaStream_t stream, scalar_t const* dataValue, int32_t const* dataSpatialShapes,
12+
int32_t const* dataLevelStartIndex, scalar_t const* dataSamplingLoc, scalar_t const* dataAttnWeight,
13+
int32_t const batchSize, int32_t const spatialSize, int32_t const numHeads, int32_t const channels, int32_t const numLevels,
14+
int32_t const numQuery, int32_t const numPoint, scalar_t* dataCol)
15+
{
16+
int32_t const numKernels = batchSize * numQuery * numHeads * channels;
17+
int32_t const numActualKernels = batchSize * numQuery * numHeads * channels;
18+
19+
ms_deformable_im2col_gpu_kernel<scalar_t><<<GET_BLOCKS(numActualKernels), THREADS_PER_BLOCK, 0, stream>>>(
20+
numKernels, dataValue, dataSpatialShapes, dataLevelStartIndex, dataSamplingLoc, dataAttnWeight, batchSize,
21+
spatialSize, numHeads, channels, numLevels, numQuery, numPoint, dataCol);
22+
}
23+
24+
25+
template <typename scalar_t>
26+
int32_t ms_deform_attn_cuda_forward(const scalar_t* value, const int32_t* spatialShapes,
27+
const int32_t* levelStartIndex, const scalar_t* samplingLoc, const scalar_t* attnWeight, scalar_t* output, int32_t batch,
28+
int32_t mSpatialSize, int32_t mNumHeads, int32_t mChannels, int32_t mNumLevels, int32_t mNumQuery, int32_t mNumPoint,
29+
cudaStream_t stream)
30+
{
31+
auto perValueSize = mSpatialSize * mNumHeads * mChannels;
32+
auto perSampleLocSize = mNumQuery * mNumHeads * mNumLevels * mNumPoint * 2;
33+
auto perAttnWeightSize = mNumQuery * mNumHeads * mNumLevels * mNumPoint;
34+
auto perOutputSize = mNumQuery * mNumHeads * mChannels;
35+
36+
int32_t mIm2colStep = batch;
37+
38+
for (int32_t n = 0; n < batch / mIm2colStep; ++n)
39+
{
40+
auto columns = output + n * mIm2colStep * perOutputSize;
41+
ms_deformable_im2col_cuda<scalar_t>(stream, value + n * mIm2colStep * perValueSize, spatialShapes, levelStartIndex,
42+
samplingLoc + n * mIm2colStep * perSampleLocSize, attnWeight + n * mIm2colStep * perAttnWeightSize, mIm2colStep,
43+
mSpatialSize, mNumHeads, mChannels, mNumLevels, mNumQuery, mNumPoint, columns);
44+
}
45+
46+
return 0;
47+
}
48+
49+
template int32_t ms_deform_attn_cuda_forward<float>(const float* value, const int32_t* spatialShapes,
50+
const int32_t* levelStartIndex, const float* samplingLoc, const float* attnWeight, float* output, int32_t batch,
51+
int32_t mSpatialSize, int32_t mNumHeads, int32_t mChannels, int32_t mNumLevels, int32_t mNumQuery, int32_t mNumPoint,
52+
cudaStream_t stream);
53+
54+
template int32_t ms_deform_attn_cuda_forward<__half>(const __half* value, const int32_t* spatialShapes,
55+
const int32_t* levelStartIndex, const __half* samplingLoc, const __half* attnWeight, __half* output, int32_t batch,
56+
int32_t mSpatialSize, int32_t mNumHeads, int32_t mChannels, int32_t mNumLevels, int32_t mNumQuery, int32_t mNumPoint,
57+
cudaStream_t stream);

0 commit comments

Comments
 (0)