@@ -21,9 +21,10 @@ void ms_deformable_im2col_cuda(cudaStream_t stream, scalar_t const* dataValue, i
21
21
spatialSize, numHeads, channels, numLevels, numQuery, numPoint, dataCol);
22
22
}
23
23
24
- template <>
25
- int32_t ms_deform_attn_cuda_forward<float >(const float * value, const int32_t * spatialShapes,
26
- const int32_t * levelStartIndex, const float * samplingLoc, const float * attnWeight, float * output, int32_t batch,
24
+
25
+ template <typename scalar_t >
26
+ int32_t ms_deform_attn_cuda_forward (const scalar_t * value, const int32_t * spatialShapes,
27
+ const int32_t * levelStartIndex, const scalar_t * samplingLoc, const scalar_t * attnWeight, scalar_t * output, int32_t batch,
27
28
int32_t mSpatialSize , int32_t mNumHeads , int32_t mChannels , int32_t mNumLevels , int32_t mNumQuery , int32_t mNumPoint ,
28
29
cudaStream_t stream)
29
30
{
@@ -37,33 +38,20 @@ int32_t ms_deform_attn_cuda_forward<float>(const float* value, const int32_t* sp
37
38
for (int32_t n = 0 ; n < batch / mIm2colStep ; ++n)
38
39
{
39
40
auto columns = output + n * mIm2colStep * perOutputSize;
40
- ms_deformable_im2col_cuda<float >(stream, value + n * mIm2colStep * perValueSize, spatialShapes, levelStartIndex,
41
+ ms_deformable_im2col_cuda<scalar_t >(stream, value + n * mIm2colStep * perValueSize, spatialShapes, levelStartIndex,
41
42
samplingLoc + n * mIm2colStep * perSampleLocSize, attnWeight + n * mIm2colStep * perAttnWeightSize, mIm2colStep ,
42
43
mSpatialSize , mNumHeads , mChannels , mNumLevels , mNumQuery , mNumPoint , columns);
43
44
}
44
45
45
46
return 0 ;
46
47
}
47
48
48
- template <>
49
- int32_t ms_deform_attn_cuda_forward<__half>(const __half* value, const int32_t * spatialShapes,
50
- const int32_t * levelStartIndex, const __half* samplingLoc, const __half* attnWeight, __half* output, int32_t batch,
49
+ template int32_t ms_deform_attn_cuda_forward<float >(const float * value, const int32_t * spatialShapes,
50
+ const int32_t * levelStartIndex, const float * samplingLoc, const float * attnWeight, float * output, int32_t batch,
51
51
int32_t mSpatialSize , int32_t mNumHeads , int32_t mChannels , int32_t mNumLevels , int32_t mNumQuery , int32_t mNumPoint ,
52
- cudaStream_t stream)
53
- {
54
- auto perValueSize = mSpatialSize * mNumHeads * mChannels ;
55
- auto perSampleLocSize = mNumQuery * mNumHeads * mNumLevels * mNumPoint * 2 ;
56
- auto perAttnWeightSize = mNumQuery * mNumHeads * mNumLevels * mNumPoint ;
57
- auto perOutputSize = mNumQuery * mNumHeads * mChannels ;
52
+ cudaStream_t stream);
58
53
59
- int32_t mIm2colStep = batch;
60
- for (int32_t n = 0 ; n < batch / mIm2colStep ; ++n)
61
- {
62
- auto columns = output + n * mIm2colStep * perOutputSize;
63
- ms_deformable_im2col_cuda<__half>(stream, value + n * mIm2colStep * perValueSize, spatialShapes, levelStartIndex,
64
- samplingLoc + n * mIm2colStep * perSampleLocSize, attnWeight + n * mIm2colStep * perAttnWeightSize, mIm2colStep ,
65
- mSpatialSize , mNumHeads , mChannels , mNumLevels , mNumQuery , mNumPoint , columns);
66
- }
67
-
68
- return 0 ;
69
- }
54
+ template int32_t ms_deform_attn_cuda_forward<__half>(const __half* value, const int32_t * spatialShapes,
55
+ const int32_t * levelStartIndex, const __half* samplingLoc, const __half* attnWeight, __half* output, int32_t batch,
56
+ int32_t mSpatialSize , int32_t mNumHeads , int32_t mChannels , int32_t mNumLevels , int32_t mNumQuery , int32_t mNumPoint ,
57
+ cudaStream_t stream);
0 commit comments