// Copyright (c) OpenMMLab. All rights reserved.

- #include "cuda_runtime.h"
+ #include "cudnn.h"
#include "mmdeploy/codebase/mmaction/format_shape.h"
#include "mmdeploy/core/utils/device_utils.h"

using namespace std;

- namespace mmdeploy {
- namespace cuda {
+ namespace mmdeploy::mmaction::cuda {

- template <typename T>
- void Transpose(const T* src, const int* src_strides, T* dst, const int* dst_strides, int ndim,
-                int total, cudaStream_t stream);
+ #define CUDNN_CHECK(condition)                                                   \
+   do {                                                                           \
+     if (condition != CUDNN_STATUS_SUCCESS) {                                     \
+       MMDEPLOY_ERROR("cudnn error, msg = {}", cudnnGetErrorString(condition));   \
+     }                                                                            \
+   } while (0);

- class FormatShapeImpl : public ::mmdeploy::FormatShapeImpl {
+ class FormatShapeImpl : public FormatShapeOp {
 public:
-   explicit FormatShapeImpl(const Value& args) : ::mmdeploy::FormatShapeImpl(args) {}
+   explicit FormatShapeImpl(std::string input_format) : FormatShapeOp(std::move(input_format)) {
+     CUDNN_CHECK(cudnnCreate(&handle_));
+     CUDNN_CHECK(cudnnSetStream(handle_, GetNative<cudaStream_t>(stream())));
+     CUDNN_CHECK(cudnnCreateTensorDescriptor(&src_desc_));
+     CUDNN_CHECK(cudnnCreateTensorDescriptor(&dst_desc_));
+   }
+
+   ~FormatShapeImpl() override {
+     CUDNN_CHECK(cudnnDestroy(handle_));
+     CUDNN_CHECK(cudnnDestroyTensorDescriptor(src_desc_));
+     CUDNN_CHECK(cudnnDestroyTensorDescriptor(dst_desc_));
+   }

 protected:
-   Result<Tensor> Format(const std::vector<Tensor>& tensors, int clip_len, int num_clips) {
-     int N = tensors.size();
-     int H = tensors[0].shape(1);
-     int W = tensors[0].shape(2);
-     int C = tensors[0].shape(3);
+   Result<void> apply(const std::vector<Tensor>& inputs, Tensor& output, int clip_len,
+                      int num_clips) override {
+     auto N = static_cast<int64_t>(inputs.size());
+     auto H = inputs[0].shape(1);
+     auto W = inputs[0].shape(2);
+     auto C = inputs[0].shape(3);

    auto t0 = std::chrono::high_resolution_clock::now();
    TensorDesc desc = {device_, DataType::kFLOAT, {N, H, W, C}};
@@ -31,39 +45,39 @@ class FormatShapeImpl : public ::mmdeploy::FormatShapeImpl {
    int n_item = H * W * C;
    int copy_size = n_item * sizeof(float);
    for (int i = 0; i < N; i++) {
-       auto src_buffer = tensors[i].buffer();
+       auto src_buffer = inputs[i].buffer();
      auto dst_buffer = imgs.buffer();
-       OUTCOME_TRY(stream_.Copy(src_buffer, dst_buffer, copy_size, 0, offset));
+       OUTCOME_TRY(stream().Copy(src_buffer, dst_buffer, copy_size, 0, offset));
      offset += copy_size;
    }

-     Tensor dst;
-     if (arg_.input_format == "NCHW") {
-       OUTCOME_TRY(dst, FormatNCHW(imgs, clip_len, num_clips));
+     // Tensor dst;
+     if (input_format_ == "NCHW") {
+       OUTCOME_TRY(output, FormatNCHW(imgs, clip_len, num_clips));
    }
-     if (arg_.input_format == "NCTHW") {
-       OUTCOME_TRY(dst, FormatNCTHW(imgs, clip_len, num_clips));
+     if (input_format_ == "NCTHW") {
+       OUTCOME_TRY(output, FormatNCTHW(imgs, clip_len, num_clips));
    }
-     TensorShape expand_dim = dst.shape();
+     TensorShape expand_dim = output.shape();
    expand_dim.insert(expand_dim.begin(), 1);
-     dst.Reshape(expand_dim);
+     output.Reshape(expand_dim);

-     return dst;
+     return success();
  }

  Result<Tensor> FormatNCHW(Tensor& src, int clip_len, int num_clips) {
-     int N = src.shape(0);
-     int H = src.shape(1);
-     int W = src.shape(2);
-     int C = src.shape(3);
+     auto N = src.shape(0);
+     auto H = src.shape(1);
+     auto W = src.shape(2);
+     auto C = src.shape(3);
    return Transpose(src, {N, H, W, C}, {0, 3, 1, 2});
  };

  Result<Tensor> FormatNCTHW(Tensor& src, int clip_len, int num_clips) {
-     int N = src.shape(0);
-     int H = src.shape(1);
-     int W = src.shape(2);
-     int C = src.shape(3);
+     auto N = src.shape(0);
+     auto H = src.shape(1);
+     auto W = src.shape(2);
+     auto C = src.shape(3);
    int L = clip_len;
    if (N % L != 0) {
      return Status(eInvalidArgument);
@@ -74,7 +88,7 @@ class FormatShapeImpl : public ::mmdeploy::FormatShapeImpl {
    return Transpose(src, {M, L, H, W, C}, {0, 4, 1, 2, 3});
  };

-   Result<Tensor> Transpose(Tensor& src, const std::vector<int>& src_dims,
+   Result<Tensor> Transpose(Tensor& src, const TensorShape& src_dims,
                             const std::vector<int>& permutation) {
    Tensor dst(src.desc());
    TensorShape shape(src.shape().size());
@@ -83,7 +97,15 @@ class FormatShapeImpl : public ::mmdeploy::FormatShapeImpl {
    }
    dst.Reshape(shape);

-     int ndim = src_dims.size();
+     SetCudnnTensorDescriptor(src_dims, permutation);
+     CUDNN_CHECK(cudnnTransformTensor(handle_, &one_, src_desc_, src.data<float>(), &zero_,
+                                      dst_desc_, dst.data<float>()));
+
+     return dst;
+   }
+
+   void SetCudnnTensorDescriptor(const TensorShape& src_dims, const std::vector<int>& permutation) {
+     auto ndim = src_dims.size();
    std::vector<int> dst_dims(ndim);
    for (int i = 0; i < ndim; i++) {
      dst_dims[i] = src_dims[permutation[i]];
@@ -102,19 +124,21 @@ class FormatShapeImpl : public ::mmdeploy::FormatShapeImpl {
      src_strides[i] = buffer[permutation[i]];
    }

-     Buffer _src_strides(Device("cuda"), sizeof(int) * ndim);
-     Buffer _dst_strides(Device("cuda"), sizeof(int) * ndim);
-     OUTCOME_TRY(stream_.Copy(src_strides.data(), _src_strides));
-     OUTCOME_TRY(stream_.Copy(dst_strides.data(), _dst_strides));
-
-     ::mmdeploy::cuda::Transpose(src.data<float>(), GetNative<int*>(_src_strides), dst.data<float>(),
-                                 GetNative<int*>(_dst_strides), ndim, src.size(),
-                                 (cudaStream_t)stream_.GetNative());
-     return dst;
+     CUDNN_CHECK(cudnnSetTensorNdDescriptor(src_desc_, CUDNN_DATA_FLOAT, ndim, dst_dims.data(),
+                                            src_strides.data()));
+     CUDNN_CHECK(cudnnSetTensorNdDescriptor(dst_desc_, CUDNN_DATA_FLOAT, ndim, dst_dims.data(),
+                                            dst_strides.data()));
  }
+
+   constexpr static float one_{1.0};
+   constexpr static float zero_{0.0};
+   cudnnHandle_t handle_{};
+   cudnnTensorDescriptor_t src_desc_{};
+   cudnnTensorDescriptor_t dst_desc_{};
};

- MMDEPLOY_REGISTER_TRANSFORM_IMPL(::mmdeploy::FormatShapeImpl, (cuda, 0), FormatShapeImpl);
+ MMDEPLOY_REGISTER_FACTORY_FUNC(FormatShapeOp, (cuda, 0), [](std::string input_format) {
+   return std::make_unique<FormatShapeImpl>(std::move(input_format));
+ });

- }  // namespace cuda
- }  // namespace mmdeploy
+ }  // namespace mmdeploy::mmaction::cuda
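
Side note on the technique, not part of the commit: the rewritten Transpose relies on a stride trick that cuDNN supports directly. SetCudnnTensorDescriptor describes both tensors with the dimensions already permuted into destination order; the source descriptor keeps the source's original strides (permuted into that order) while the destination descriptor gets packed strides, and cudnnTransformTensor then performs the physical permutation. Below is a minimal standalone sketch of that trick for an NHWC-to-NCHW transpose, using only public CUDA/cuDNN APIs; the helper macro CHECK_CUDNN and the toy 1x2x2x3 shape are illustrative, not taken from the patch.

// Minimal sketch: NHWC (1x2x2x3) -> NCHW (1x3x2x2) via cudnnTransformTensor,
// using permuted strides on the source descriptor and packed strides on the destination.
#include <cuda_runtime.h>
#include <cudnn.h>

#include <cstdio>
#include <cstdlib>
#include <vector>

#define CHECK_CUDNN(expr)                                                 \
  do {                                                                    \
    cudnnStatus_t s = (expr);                                             \
    if (s != CUDNN_STATUS_SUCCESS) {                                      \
      std::fprintf(stderr, "cudnn error: %s\n", cudnnGetErrorString(s));  \
      std::exit(1);                                                       \
    }                                                                     \
  } while (0)

int main() {
  cudnnHandle_t handle;
  CHECK_CUDNN(cudnnCreate(&handle));

  // Source is NHWC with dims {N, H, W, C} = {1, 2, 2, 3} and packed strides {12, 6, 3, 1}.
  // Both descriptors use the *destination* dimension order NCHW = {1, 3, 2, 2}:
  //   - src_desc keeps the NHWC strides, permuted into NCHW order -> {12, 1, 6, 3}
  //   - dst_desc gets packed NCHW strides                         -> {12, 4, 2, 1}
  int dims[4] = {1, 3, 2, 2};
  int src_strides[4] = {12, 1, 6, 3};
  int dst_strides[4] = {12, 4, 2, 1};

  cudnnTensorDescriptor_t src_desc, dst_desc;
  CHECK_CUDNN(cudnnCreateTensorDescriptor(&src_desc));
  CHECK_CUDNN(cudnnCreateTensorDescriptor(&dst_desc));
  CHECK_CUDNN(cudnnSetTensorNdDescriptor(src_desc, CUDNN_DATA_FLOAT, 4, dims, src_strides));
  CHECK_CUDNN(cudnnSetTensorNdDescriptor(dst_desc, CUDNN_DATA_FLOAT, 4, dims, dst_strides));

  std::vector<float> h_src(12);
  for (int i = 0; i < 12; i++) h_src[i] = static_cast<float>(i);  // values 0..11 in NHWC order

  float *d_src = nullptr, *d_dst = nullptr;
  cudaMalloc(&d_src, 12 * sizeof(float));
  cudaMalloc(&d_dst, 12 * sizeof(float));
  cudaMemcpy(d_src, h_src.data(), 12 * sizeof(float), cudaMemcpyHostToDevice);

  // y = alpha * permute(x) + beta * y; alpha = 1, beta = 0 copies the permuted data.
  const float one = 1.0f, zero = 0.0f;
  CHECK_CUDNN(cudnnTransformTensor(handle, &one, src_desc, d_src, &zero, dst_desc, d_dst));

  std::vector<float> h_dst(12);
  cudaMemcpy(h_dst.data(), d_dst, 12 * sizeof(float), cudaMemcpyDeviceToHost);
  for (float v : h_dst) std::printf("%.0f ", v);  // expect: 0 3 6 9 1 4 7 10 2 5 8 11
  std::printf("\n");

  CHECK_CUDNN(cudnnDestroyTensorDescriptor(src_desc));
  CHECK_CUDNN(cudnnDestroyTensorDescriptor(dst_desc));
  CHECK_CUDNN(cudnnDestroy(handle));
  cudaFree(d_src);
  cudaFree(d_dst);
  return 0;
}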