1
1
// Copyright (c) OpenMMLab. All rights reserved.
2
2
3
- #include " cudnn .h"
3
+ #include " cuda_runtime .h"
4
4
#include " mmdeploy/codebase/mmaction/format_shape.h"
5
5
#include " mmdeploy/core/utils/device_utils.h"
6
6
7
7
using namespace std ;
8
8
9
9
namespace mmdeploy ::mmaction::cuda {
10
10
11
- #define CUDNN_CHECK (condition ) \
12
- do { \
13
- if (condition != CUDNN_STATUS_SUCCESS) { \
14
- MMDEPLOY_ERROR (" cudnn error, msg = {}" , cudnnGetErrorString (condition)); \
15
- } \
16
- } while (0 );
11
+ template <typename T>
12
+ void Transpose (const T* src, const int * src_strides, T* dst, const int * dst_strides, int ndim,
13
+ int total, cudaStream_t stream);
17
14
18
15
class FormatShapeImpl : public FormatShapeOp {
19
16
public:
20
- explicit FormatShapeImpl (std::string input_format) : FormatShapeOp(std::move(input_format)) {
21
- CUDNN_CHECK (cudnnCreate (&handle_));
22
- CUDNN_CHECK (cudnnSetStream (handle_, GetNative<cudaStream_t>(stream ())));
23
- CUDNN_CHECK (cudnnCreateTensorDescriptor (&src_desc_));
24
- CUDNN_CHECK (cudnnCreateTensorDescriptor (&dst_desc_));
25
- }
26
-
27
- ~FormatShapeImpl () override {
28
- CUDNN_CHECK (cudnnDestroy (handle_));
29
- CUDNN_CHECK (cudnnDestroyTensorDescriptor (src_desc_));
30
- CUDNN_CHECK (cudnnDestroyTensorDescriptor (dst_desc_));
31
- }
17
+ explicit FormatShapeImpl (std::string input_format) : FormatShapeOp(std::move(input_format)) {}
32
18
33
19
protected:
34
- Result<void > apply (const std::vector<Tensor>& inputs, Tensor& output, int clip_len,
35
- int num_clips) override {
36
- auto N = static_cast <int64_t >(inputs.size ());
37
- auto H = inputs[0 ].shape (1 );
38
- auto W = inputs[0 ].shape (2 );
39
- auto C = inputs[0 ].shape (3 );
40
-
41
- auto t0 = std::chrono::high_resolution_clock::now ();
42
- TensorDesc desc = {device_, DataType::kFLOAT , {N, H, W, C}};
43
- Tensor imgs (desc);
44
- int offset = 0 ;
45
- int n_item = H * W * C;
46
- int copy_size = n_item * sizeof (float );
47
- for (int i = 0 ; i < N; i++) {
48
- auto src_buffer = inputs[i].buffer ();
49
- auto dst_buffer = imgs.buffer ();
50
- OUTCOME_TRY (stream ().Copy (src_buffer, dst_buffer, copy_size, 0 , offset));
51
- offset += copy_size;
52
- }
53
-
54
- // Tensor dst;
55
- if (input_format_ == " NCHW" ) {
56
- OUTCOME_TRY (output, FormatNCHW (imgs, clip_len, num_clips));
57
- }
58
- if (input_format_ == " NCTHW" ) {
59
- OUTCOME_TRY (output, FormatNCTHW (imgs, clip_len, num_clips));
60
- }
61
- TensorShape expand_dim = output.shape ();
62
- expand_dim.insert (expand_dim.begin (), 1 );
63
- output.Reshape (expand_dim);
64
-
65
- return success ();
66
- }
67
-
68
- Result<Tensor> FormatNCHW (Tensor& src, int clip_len, int num_clips) {
69
- auto N = src.shape (0 );
70
- auto H = src.shape (1 );
71
- auto W = src.shape (2 );
72
- auto C = src.shape (3 );
73
- return Transpose (src, {N, H, W, C}, {0 , 3 , 1 , 2 });
74
- };
75
-
76
- Result<Tensor> FormatNCTHW (Tensor& src, int clip_len, int num_clips) {
77
- auto N = src.shape (0 );
78
- auto H = src.shape (1 );
79
- auto W = src.shape (2 );
80
- auto C = src.shape (3 );
81
- int L = clip_len;
82
- if (N % L != 0 ) {
83
- return Status (eInvalidArgument);
84
- }
85
- int M = N / L;
86
- src.Reshape ({M, L, H, W, C});
87
-
88
- return Transpose (src, {M, L, H, W, C}, {0 , 4 , 1 , 2 , 3 });
89
- };
20
+ const Device& GetDevice () { return device (); }
90
21
91
22
Result<Tensor> Transpose (Tensor& src, const TensorShape& src_dims,
92
23
const std::vector<int >& permutation) {
@@ -97,14 +28,6 @@ class FormatShapeImpl : public FormatShapeOp {
97
28
}
98
29
dst.Reshape (shape);
99
30
100
- SetCudnnTensorDescriptor (src_dims, permutation);
101
- CUDNN_CHECK (cudnnTransformTensor (handle_, &one_, src_desc_, src.data <float >(), &zero_,
102
- dst_desc_, dst.data <float >()));
103
-
104
- return dst;
105
- }
106
-
107
- void SetCudnnTensorDescriptor (const TensorShape& src_dims, const std::vector<int >& permutation) {
108
31
auto ndim = src_dims.size ();
109
32
std::vector<int > dst_dims (ndim);
110
33
for (int i = 0 ; i < ndim; i++) {
@@ -124,17 +47,16 @@ class FormatShapeImpl : public FormatShapeOp {
124
47
src_strides[i] = buffer[permutation[i]];
125
48
}
126
49
127
- CUDNN_CHECK (cudnnSetTensorNdDescriptor (src_desc_, CUDNN_DATA_FLOAT, ndim, dst_dims.data (),
128
- src_strides.data ()));
129
- CUDNN_CHECK (cudnnSetTensorNdDescriptor (dst_desc_, CUDNN_DATA_FLOAT, ndim, dst_dims.data (),
130
- dst_strides.data ()));
131
- }
50
+ Buffer _src_strides (Device (" cuda" ), sizeof (int ) * ndim);
51
+ Buffer _dst_strides (Device (" cuda" ), sizeof (int ) * ndim);
52
+ OUTCOME_TRY (stream ().Copy (src_strides.data (), _src_strides));
53
+ OUTCOME_TRY (stream ().Copy (dst_strides.data (), _dst_strides));
132
54
133
- constexpr static float one_{ 1.0 };
134
- constexpr static float zero_{ 0.0 };
135
- cudnnHandle_t handle_{} ;
136
- cudnnTensorDescriptor_t src_desc_{} ;
137
- cudnnTensorDescriptor_t dst_desc_{};
55
+ ::mmdeploy::mmaction::cuda::Transpose (src.data< float >(), GetNative<int*>(_src_strides),
56
+ dst.data<float>(), GetNative<int*>(_dst_strides), ndim,
57
+ src.size(), (cudaStream_t)stream().GetNative()) ;
58
+ return dst ;
59
+ }
138
60
};
139
61
140
62
MMDEPLOY_REGISTER_FACTORY_FUNC (FormatShapeOp, (cuda, 0 ), [](std::string input_format) {
0 commit comments