Skip to content

Commit 16d7c1e

Browse files
committed
fix template
1 parent dcf0caf commit 16d7c1e

File tree

1 file changed

+12
-24
lines changed

1 file changed

+12
-24
lines changed

csrc/mmdeploy/backend_ops/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn_kernel.cu

Lines changed: 12 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -21,9 +21,10 @@ void ms_deformable_im2col_cuda(cudaStream_t stream, scalar_t const* dataValue, i
2121
spatialSize, numHeads, channels, numLevels, numQuery, numPoint, dataCol);
2222
}
2323

24-
template <>
25-
int32_t ms_deform_attn_cuda_forward<float>(const float* value, const int32_t* spatialShapes,
26-
const int32_t* levelStartIndex, const float* samplingLoc, const float* attnWeight, float* output, int32_t batch,
24+
25+
template <typename scalar_t>
26+
int32_t ms_deform_attn_cuda_forward(const scalar_t* value, const int32_t* spatialShapes,
27+
const int32_t* levelStartIndex, const scalar_t* samplingLoc, const scalar_t* attnWeight, scalar_t* output, int32_t batch,
2728
int32_t mSpatialSize, int32_t mNumHeads, int32_t mChannels, int32_t mNumLevels, int32_t mNumQuery, int32_t mNumPoint,
2829
cudaStream_t stream)
2930
{
@@ -37,33 +38,20 @@ int32_t ms_deform_attn_cuda_forward<float>(const float* value, const int32_t* sp
3738
for (int32_t n = 0; n < batch / mIm2colStep; ++n)
3839
{
3940
auto columns = output + n * mIm2colStep * perOutputSize;
40-
ms_deformable_im2col_cuda<float>(stream, value + n * mIm2colStep * perValueSize, spatialShapes, levelStartIndex,
41+
ms_deformable_im2col_cuda<scalar_t>(stream, value + n * mIm2colStep * perValueSize, spatialShapes, levelStartIndex,
4142
samplingLoc + n * mIm2colStep * perSampleLocSize, attnWeight + n * mIm2colStep * perAttnWeightSize, mIm2colStep,
4243
mSpatialSize, mNumHeads, mChannels, mNumLevels, mNumQuery, mNumPoint, columns);
4344
}
4445

4546
return 0;
4647
}
4748

48-
template <>
49-
int32_t ms_deform_attn_cuda_forward<__half>(const __half* value, const int32_t* spatialShapes,
50-
const int32_t* levelStartIndex, const __half* samplingLoc, const __half* attnWeight, __half* output, int32_t batch,
49+
template int32_t ms_deform_attn_cuda_forward<float>(const float* value, const int32_t* spatialShapes,
50+
const int32_t* levelStartIndex, const float* samplingLoc, const float* attnWeight, float* output, int32_t batch,
5151
int32_t mSpatialSize, int32_t mNumHeads, int32_t mChannels, int32_t mNumLevels, int32_t mNumQuery, int32_t mNumPoint,
52-
cudaStream_t stream)
53-
{
54-
auto perValueSize = mSpatialSize * mNumHeads * mChannels;
55-
auto perSampleLocSize = mNumQuery * mNumHeads * mNumLevels * mNumPoint * 2;
56-
auto perAttnWeightSize = mNumQuery * mNumHeads * mNumLevels * mNumPoint;
57-
auto perOutputSize = mNumQuery * mNumHeads * mChannels;
52+
cudaStream_t stream);
5853

59-
int32_t mIm2colStep = batch;
60-
for (int32_t n = 0; n < batch / mIm2colStep; ++n)
61-
{
62-
auto columns = output + n * mIm2colStep * perOutputSize;
63-
ms_deformable_im2col_cuda<__half>(stream, value + n * mIm2colStep * perValueSize, spatialShapes, levelStartIndex,
64-
samplingLoc + n * mIm2colStep * perSampleLocSize, attnWeight + n * mIm2colStep * perAttnWeightSize, mIm2colStep,
65-
mSpatialSize, mNumHeads, mChannels, mNumLevels, mNumQuery, mNumPoint, columns);
66-
}
67-
68-
return 0;
69-
}
54+
template int32_t ms_deform_attn_cuda_forward<__half>(const __half* value, const int32_t* spatialShapes,
55+
const int32_t* levelStartIndex, const __half* samplingLoc, const __half* attnWeight, __half* output, int32_t batch,
56+
int32_t mSpatialSize, int32_t mNumHeads, int32_t mChannels, int32_t mNumLevels, int32_t mNumQuery, int32_t mNumPoint,
57+
cudaStream_t stream);

0 commit comments

Comments
 (0)