|
| 1 | +// Copyright (c) OpenMMLab. All rights reserved |
| 2 | +#include "trt_ms_deform_attn.hpp" |
| 3 | + |
| 4 | +#include <assert.h> |
| 5 | + |
| 6 | +#include <chrono> |
| 7 | + |
| 8 | +#include "trt_ms_deform_attn_kernel.hpp" |
| 9 | +#include "trt_serialize.hpp" |
| 10 | + |
| 11 | +using namespace nvinfer1; |
| 12 | + |
| 13 | +namespace mmdeploy { |
| 14 | +namespace { |
| 15 | +static const char *PLUGIN_VERSION{"1"}; |
| 16 | +static const char *PLUGIN_NAME{"MMCVMultiScaleDeformableAttention"}; |
| 17 | +} // namespace |
| 18 | + |
| 19 | +MultiScaleDeformableAttnPluginDynamic::MultiScaleDeformableAttnPluginDynamic(const std::string &name) |
| 20 | + : TRTPluginBase(name) {} |
| 21 | + |
| 22 | +MultiScaleDeformableAttnPluginDynamic::MultiScaleDeformableAttnPluginDynamic(const std::string name, |
| 23 | + const void *data, |
| 24 | + size_t length) |
| 25 | + : TRTPluginBase(name) {} |
| 26 | +MultiScaleDeformableAttnPluginDynamic::~MultiScaleDeformableAttnPluginDynamic() {} |
| 27 | + |
| 28 | +nvinfer1::IPluginV2DynamicExt *MultiScaleDeformableAttnPluginDynamic::clone() const TRT_NOEXCEPT { |
| 29 | + MultiScaleDeformableAttnPluginDynamic *plugin = new MultiScaleDeformableAttnPluginDynamic(mLayerName); |
| 30 | + plugin->setPluginNamespace(getPluginNamespace()); |
| 31 | + |
| 32 | + return plugin; |
| 33 | +} |
| 34 | + |
| 35 | +nvinfer1::DimsExprs MultiScaleDeformableAttnPluginDynamic::getOutputDimensions( |
| 36 | + int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, |
| 37 | + nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT { |
| 38 | + nvinfer1::DimsExprs ret; |
| 39 | + ret.nbDims = 3; |
| 40 | + ret.d[0] = inputs[0].d[0]; |
| 41 | + ret.d[1] = inputs[3].d[1]; |
| 42 | + |
| 43 | + ret.d[2] = exprBuilder.operation(DimensionOperation::kPROD, |
| 44 | + *inputs[0].d[2], *inputs[0].d[3]); |
| 45 | + |
| 46 | + return ret; |
| 47 | +} |
| 48 | + |
| 49 | +bool MultiScaleDeformableAttnPluginDynamic::supportsFormatCombination( |
| 50 | + int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs, int nbOutputs) TRT_NOEXCEPT { |
| 51 | + |
| 52 | + if (ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR) |
| 53 | + { |
| 54 | + if ((pos == 1) || (pos == 2)) |
| 55 | + { |
| 56 | + return (ioDesc[pos].type == nvinfer1::DataType::kINT32); |
| 57 | + } |
| 58 | + else |
| 59 | + { |
| 60 | + return ((ioDesc[pos].type == ioDesc[0].type) && |
| 61 | + ((ioDesc[pos].type == nvinfer1::DataType::kFLOAT) || (ioDesc[pos].type == nvinfer1::DataType::kHALF))); |
| 62 | + } |
| 63 | + } |
| 64 | + else |
| 65 | + { |
| 66 | + return false; |
| 67 | + } |
| 68 | +} |
| 69 | + |
| 70 | +void MultiScaleDeformableAttnPluginDynamic::configurePlugin( |
| 71 | + const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs, |
| 72 | + const nvinfer1::DynamicPluginTensorDesc *outputs, int nbOutputs) TRT_NOEXCEPT { |
| 73 | +} |
| 74 | + |
| 75 | +size_t MultiScaleDeformableAttnPluginDynamic::getWorkspaceSize( |
| 76 | + const nvinfer1::PluginTensorDesc *inputs, int nbInputs, |
| 77 | + const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const TRT_NOEXCEPT { |
| 78 | + return 0; |
| 79 | +} |
| 80 | + |
| 81 | +int MultiScaleDeformableAttnPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, |
| 82 | + const nvinfer1::PluginTensorDesc *outputDesc, |
| 83 | + const void *const *inputs, void *const *outputs, |
| 84 | + void *workSpace, |
| 85 | + cudaStream_t stream) TRT_NOEXCEPT { |
| 86 | + int32_t const batch = inputDesc[0].dims.d[0]; |
| 87 | + int32_t spatial_size = inputDesc[0].dims.d[1]; |
| 88 | + int32_t num_heads = inputDesc[0].dims.d[2]; |
| 89 | + int32_t channels = inputDesc[0].dims.d[3]; |
| 90 | + int32_t num_levels = inputDesc[1].dims.d[0]; |
| 91 | + int32_t num_query = inputDesc[3].dims.d[1]; |
| 92 | + int32_t num_point = inputDesc[3].dims.d[4]; |
| 93 | + int32_t rc = 0; |
| 94 | + if (inputDesc[0].type == nvinfer1::DataType::kFLOAT) |
| 95 | + { |
| 96 | + float const* value = static_cast<float const*>(inputs[0]); |
| 97 | + int32_t const* spatialShapes = static_cast<int32_t const*>(inputs[1]); |
| 98 | + int32_t const* levelStartIndex = static_cast<int32_t const*>(inputs[2]); |
| 99 | + float const* samplingLoc = static_cast<float const*>(inputs[3]); |
| 100 | + float const* attnWeight = static_cast<float const*>(inputs[4]); |
| 101 | + float* output = static_cast<float*>(outputs[0]); |
| 102 | + |
| 103 | + rc = ms_deform_attn_cuda_forward(value, spatialShapes, levelStartIndex, samplingLoc, attnWeight, output, |
| 104 | + batch, spatial_size, num_heads, channels, num_levels, num_query, num_point, stream); |
| 105 | + } |
| 106 | + else if (inputDesc[0].type == nvinfer1::DataType::kHALF) |
| 107 | + { |
| 108 | + const __half* value = static_cast<const __half*>(inputs[0]); |
| 109 | + int32_t const* spatialShapes = static_cast<int32_t const*>(inputs[1]); |
| 110 | + int32_t const* levelStartIndex = static_cast<int32_t const*>(inputs[2]); |
| 111 | + const __half* samplingLoc = static_cast<const __half*>(inputs[3]); |
| 112 | + const __half* attnWeight = static_cast<const __half*>(inputs[4]); |
| 113 | + __half* output = static_cast<__half*>(outputs[0]); |
| 114 | + |
| 115 | + rc = ms_deform_attn_cuda_forward(value, spatialShapes, levelStartIndex, samplingLoc, attnWeight, output, |
| 116 | + batch, spatial_size, num_heads, channels, num_levels, num_query, num_point, stream); |
| 117 | + } |
| 118 | + |
| 119 | + return rc; |
| 120 | +} |
| 121 | + |
| 122 | +nvinfer1::DataType MultiScaleDeformableAttnPluginDynamic::getOutputDataType( |
| 123 | + int index, const nvinfer1::DataType *inputTypes, int nbInputs) const TRT_NOEXCEPT { |
| 124 | + return inputTypes[0]; |
| 125 | +} |
| 126 | + |
| 127 | +// IPluginV2 Methods |
| 128 | +const char *MultiScaleDeformableAttnPluginDynamic::getPluginType() const TRT_NOEXCEPT { |
| 129 | + return PLUGIN_NAME; |
| 130 | +} |
| 131 | + |
| 132 | +const char *MultiScaleDeformableAttnPluginDynamic::getPluginVersion() const TRT_NOEXCEPT { |
| 133 | + return PLUGIN_VERSION; |
| 134 | +} |
| 135 | + |
| 136 | +int MultiScaleDeformableAttnPluginDynamic::getNbOutputs() const TRT_NOEXCEPT { return 1; } |
| 137 | + |
| 138 | +size_t MultiScaleDeformableAttnPluginDynamic::getSerializationSize() const TRT_NOEXCEPT { |
| 139 | + return 0; |
| 140 | +} |
| 141 | + |
| 142 | +void MultiScaleDeformableAttnPluginDynamic::serialize(void *buffer) const TRT_NOEXCEPT {} |
| 143 | + |
| 144 | +void MultiScaleDeformableAttnPluginDynamic::attachToContext( |
| 145 | + cudnnContext *cudnnContext, cublasContext *cublasContext, |
| 146 | + nvinfer1::IGpuAllocator *gpuAllocator) TRT_NOEXCEPT {} |
| 147 | + |
| 148 | +void MultiScaleDeformableAttnPluginDynamic::detachFromContext() TRT_NOEXCEPT {} |
| 149 | + |
| 150 | +////////////////////// creator ///////////////////////////// |
| 151 | + |
| 152 | +MultiScaleDeformableAttnPluginDynamicCreator::MultiScaleDeformableAttnPluginDynamicCreator() { |
| 153 | + mPluginAttributes.clear(); |
| 154 | + mFC.nbFields = mPluginAttributes.size(); |
| 155 | + mFC.fields = mPluginAttributes.data(); |
| 156 | +} |
| 157 | + |
| 158 | +const char *MultiScaleDeformableAttnPluginDynamicCreator::getPluginName() const TRT_NOEXCEPT { |
| 159 | + return PLUGIN_NAME; |
| 160 | +} |
| 161 | + |
| 162 | +const char *MultiScaleDeformableAttnPluginDynamicCreator::getPluginVersion() const TRT_NOEXCEPT { |
| 163 | + return PLUGIN_VERSION; |
| 164 | +} |
| 165 | + |
| 166 | +nvinfer1::IPluginV2 *MultiScaleDeformableAttnPluginDynamicCreator::createPlugin( |
| 167 | + const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT { |
| 168 | + |
| 169 | + MultiScaleDeformableAttnPluginDynamic *plugin = new MultiScaleDeformableAttnPluginDynamic(name); |
| 170 | + plugin->setPluginNamespace(getPluginNamespace()); |
| 171 | + return plugin; |
| 172 | +} |
| 173 | + |
| 174 | +nvinfer1::IPluginV2 *MultiScaleDeformableAttnPluginDynamicCreator::deserializePlugin( |
| 175 | + const char *name, const void *serialData, size_t serialLength) TRT_NOEXCEPT { |
| 176 | + auto plugin = new MultiScaleDeformableAttnPluginDynamic(name, serialData, serialLength); |
| 177 | + plugin->setPluginNamespace(getPluginNamespace()); |
| 178 | + return plugin; |
| 179 | +} |
| 180 | +REGISTER_TENSORRT_PLUGIN(MultiScaleDeformableAttnPluginDynamicCreator); |
| 181 | +} // namespace mmdeploy |
0 commit comments