Commit 04c9793

cherry-pick: Decouple preprocess operation and transformation (open-mmlab#1353)

1 parent 09c6bd7

126 files changed, +2918 -3725 lines


csrc/mmdeploy/CMakeLists.txt
Lines changed: 1 addition & 0 deletions

```diff
@@ -13,6 +13,7 @@ if (MMDEPLOY_BUILD_SDK)
     add_subdirectory(device)
     add_subdirectory(graph)
     add_subdirectory(model)
+    add_subdirectory(operation)
     add_subdirectory(preprocess)
     add_subdirectory(net)
     add_subdirectory(codebase)
```

csrc/mmdeploy/codebase/mmaction/CMakeLists.txt
Lines changed: 2 additions & 1 deletion

```diff
@@ -7,7 +7,8 @@ mmdeploy_add_module(${PROJECT_NAME} "${SRCS}")
 add_subdirectory(cpu)
 add_subdirectory(cuda)
 target_link_libraries(${PROJECT_NAME} PRIVATE
-        mmdeploy::transform
+        mmdeploy_operation
+        mmdeploy_transform
         mmdeploy_opencv_utils)
 
 add_library(mmdeploy::mmaction ALIAS ${PROJECT_NAME})
```
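
The link-line change above is the build-system face of the decoupling: the old monolithic `mmdeploy::transform` target is replaced by `mmdeploy_operation` (device-dispatched preprocessing kernels, built from the new `operation` subdirectory) plus `mmdeploy_transform` (the device-independent transform logic). A codebase module added after this commit would presumably follow the same pattern; a sketch, where the module name `mmdeploy_mmfoo` and the source glob are placeholders, not part of the commit:

```cmake
# Sketch of a new codebase module's CMakeLists.txt under the decoupled layout.
# "mmdeploy_mmfoo" is a hypothetical module name for illustration only.
project(mmdeploy_mmfoo)

file(GLOB SRCS "*.cpp")
mmdeploy_add_module(${PROJECT_NAME} "${SRCS}")

target_link_libraries(${PROJECT_NAME} PRIVATE
        mmdeploy_operation    # device-specific preprocess kernels
        mmdeploy_transform    # device-independent transform pipeline
        mmdeploy_opencv_utils)

add_library(mmdeploy::mmfoo ALIAS ${PROJECT_NAME})
```
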

csrc/mmdeploy/codebase/mmaction/cpu/format_shape_impl.cpp
Lines changed: 34 additions & 39 deletions

```diff
@@ -5,69 +5,63 @@
 
 using namespace std;
 
-namespace mmdeploy {
-namespace cpu {
+namespace mmdeploy::mmaction::cpu {
 
-class FormatShapeImpl : public ::mmdeploy::FormatShapeImpl {
+class FormatShapeImpl : public FormatShapeOp {
  public:
-  explicit FormatShapeImpl(const Value& args) : ::mmdeploy::FormatShapeImpl(args) {}
+  explicit FormatShapeImpl(std::string input_format) : FormatShapeOp(std::move(input_format)) {}
 
  protected:
-  Result<Tensor> Format(const std::vector<Tensor>& tensors, int clip_len, int num_clips) {
-    int N = tensors.size();
-    int H = tensors[0].shape(1);
-    int W = tensors[0].shape(2);
-    int C = tensors[0].shape(3);
-
-    std::vector<Tensor> host_tensors;
-    host_tensors.reserve(N);
-    for (int i = 0; i < N; i++) {
-      OUTCOME_TRY(auto src_tensor, MakeAvailableOnDevice(tensors[i], kHost, stream_));
-      host_tensors.push_back(std::move(src_tensor));
-    }
-    OUTCOME_TRY(stream_.Wait());
+  Result<void> apply(const std::vector<Tensor>& tensors, Tensor& output, int clip_len,
+                     int num_clips) override {
+    auto N = static_cast<int64_t>(tensors.size());
+    auto H = tensors[0].shape(1);
+    auto W = tensors[0].shape(2);
+    auto C = tensors[0].shape(3);
 
     TensorDesc desc = {kHost, DataType::kFLOAT, {N, H, W, C}};
     Tensor imgs(desc);
-    int offset = 0;
-    int n_item = H * W * C;
-    int copy_size = n_item * sizeof(float);
+    auto offset = 0UL;
+    auto n_item = H * W * C;
+    auto copy_size = n_item * sizeof(float);
     for (int i = 0; i < N; i++) {
-      auto src_buffer = host_tensors[i].buffer();
+      auto src_buffer = tensors[i].buffer();
       auto dst_buffer = imgs.buffer();
-      OUTCOME_TRY(stream_.Copy(src_buffer, dst_buffer, copy_size, 0, offset));
+      OUTCOME_TRY(stream().Copy(src_buffer, dst_buffer, copy_size, 0, offset));
       offset += copy_size;
     }
-    OUTCOME_TRY(stream_.Wait());
+
+    OUTCOME_TRY(stream().Wait());
 
     Tensor dst;
-    if (arg_.input_format == "NCHW") {
+    if (input_format_ == "NCHW") {
       OUTCOME_TRY(dst, FormatNCHW(imgs, clip_len, num_clips));
     }
-    if (arg_.input_format == "NCTHW") {
+    if (input_format_ == "NCTHW") {
       OUTCOME_TRY(dst, FormatNCTHW(imgs, clip_len, num_clips));
     }
     TensorShape expand_dim = dst.shape();
     expand_dim.insert(expand_dim.begin(), 1);
     dst.Reshape(expand_dim);
+    output = std::move(dst);
 
-    return dst;
+    return success();
   }
 
   Result<Tensor> FormatNCHW(Tensor& src, int clip_len, int num_clips) {
-    int N = src.shape(0);
-    int H = src.shape(1);
-    int W = src.shape(2);
-    int C = src.shape(3);
+    auto N = src.shape(0);
+    auto H = src.shape(1);
+    auto W = src.shape(2);
+    auto C = src.shape(3);
     return Transpose(src, {N, H, W, C}, {0, 3, 1, 2});
   };
 
   Result<Tensor> FormatNCTHW(Tensor& src, int clip_len, int num_clips) {
-    int N = src.shape(0);
-    int H = src.shape(1);
-    int W = src.shape(2);
-    int C = src.shape(3);
-    int L = clip_len;
+    auto N = src.shape(0);
+    auto H = src.shape(1);
+    auto W = src.shape(2);
+    auto C = src.shape(3);
+    auto L = clip_len;
     if (N % L != 0) {
       return Status(eInvalidArgument);
     }
@@ -77,7 +71,7 @@ class FormatShapeImpl : public ::mmdeploy::FormatShapeImpl {
     return Transpose(src, {M, L, H, W, C}, {0, 4, 1, 2, 3});
   };
 
-  Result<Tensor> Transpose(Tensor& src, const std::vector<int>& src_dims,
+  Result<Tensor> Transpose(Tensor& src, const TensorShape& src_dims,
                            const std::vector<int>& permutation) {
     Tensor dst(src.desc());
     TensorShape shape(src.shape().size());
@@ -123,7 +117,8 @@ class FormatShapeImpl : public ::mmdeploy::FormatShapeImpl {
   constexpr static Device kHost{0, 0};
 };
 
-MMDEPLOY_REGISTER_TRANSFORM_IMPL(::mmdeploy::FormatShapeImpl, (cpu, 0), FormatShapeImpl);
+MMDEPLOY_REGISTER_FACTORY_FUNC(FormatShapeOp, (cpu, 0), [](std::string input_format) {
+  return std::make_unique<FormatShapeImpl>(std::move(input_format));
+});
 
-}  // namespace cpu
-}  // namespace mmdeploy
+}  // namespace mmdeploy::mmaction::cpu
```
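
The registration swap at the bottom of this file is the core of the refactor: instead of registering a whole per-device transform implementation (`MMDEPLOY_REGISTER_TRANSFORM_IMPL`), each device now contributes only a factory function for a small operation interface (`MMDEPLOY_REGISTER_FACTORY_FUNC`), so a shared, device-agnostic transform can create the right kernel at runtime. The sketch below shows that factory-registration pattern in self-contained C++; `OpRegistry` and the simplified `apply()` are invented stand-ins for this illustration, not mmdeploy's actual registry internals:

```cpp
// Self-contained illustration of the factory-registration pattern above.
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>

// Device-agnostic operation interface (cf. FormatShapeOp in the commit).
struct FormatShapeOp {
  explicit FormatShapeOp(std::string input_format) : input_format_(std::move(input_format)) {}
  virtual ~FormatShapeOp() = default;
  virtual void apply() = 0;  // the real op takes tensors and returns Result<void>
  std::string input_format_;
};

// Hypothetical registry mapping a device name to a factory function.
struct OpRegistry {
  using Factory = std::function<std::unique_ptr<FormatShapeOp>(std::string)>;
  static std::map<std::string, Factory>& factories() {
    static std::map<std::string, Factory> m;
    return m;
  }
};

// A device backend registers only a factory, the way the commit's lambda
// does, instead of a whole transform class.
struct FormatShapeCpu : FormatShapeOp {
  using FormatShapeOp::FormatShapeOp;
  void apply() override { std::cout << "cpu format: " << input_format_ << "\n"; }
};

static const bool cpu_registered = [] {
  OpRegistry::factories()["cpu"] = [](std::string input_format) {
    return std::make_unique<FormatShapeCpu>(std::move(input_format));
  };
  return true;
}();

int main() {
  // A device-independent transform can now create the right kernel at runtime.
  auto op = OpRegistry::factories().at("cpu")("NCHW");
  op->apply();  // prints "cpu format: NCHW"
}
```

The payoff of this split is that the transform layer compiles once against the interface, while each device backend sits behind a uniform creation point, which is what lets the commit delete the per-device transform classes.
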
csrc/mmdeploy/codebase/mmaction/cuda/format_shape_impl.cpp
Lines changed: 69 additions & 45 deletions

```diff
@@ -1,28 +1,42 @@
 // Copyright (c) OpenMMLab. All rights reserved.
 
-#include "cuda_runtime.h"
+#include "cudnn.h"
 #include "mmdeploy/codebase/mmaction/format_shape.h"
 #include "mmdeploy/core/utils/device_utils.h"
 
 using namespace std;
 
-namespace mmdeploy {
-namespace cuda {
+namespace mmdeploy::mmaction::cuda {
 
-template <typename T>
-void Transpose(const T* src, const int* src_strides, T* dst, const int* dst_strides, int ndim,
-               int total, cudaStream_t stream);
+#define CUDNN_CHECK(condition)                                                 \
+  do {                                                                         \
+    if (condition != CUDNN_STATUS_SUCCESS) {                                   \
+      MMDEPLOY_ERROR("cudnn error, msg = {}", cudnnGetErrorString(condition)); \
+    }                                                                          \
+  } while (0);
 
-class FormatShapeImpl : public ::mmdeploy::FormatShapeImpl {
+class FormatShapeImpl : public FormatShapeOp {
  public:
-  explicit FormatShapeImpl(const Value& args) : ::mmdeploy::FormatShapeImpl(args) {}
+  explicit FormatShapeImpl(std::string input_format) : FormatShapeOp(std::move(input_format)) {
+    CUDNN_CHECK(cudnnCreate(&handle_));
+    CUDNN_CHECK(cudnnSetStream(handle_, GetNative<cudaStream_t>(stream())));
+    CUDNN_CHECK(cudnnCreateTensorDescriptor(&src_desc_));
+    CUDNN_CHECK(cudnnCreateTensorDescriptor(&dst_desc_));
+  }
+
+  ~FormatShapeImpl() override {
+    CUDNN_CHECK(cudnnDestroy(handle_));
+    CUDNN_CHECK(cudnnDestroyTensorDescriptor(src_desc_));
+    CUDNN_CHECK(cudnnDestroyTensorDescriptor(dst_desc_));
+  }
 
  protected:
-  Result<Tensor> Format(const std::vector<Tensor>& tensors, int clip_len, int num_clips) {
-    int N = tensors.size();
-    int H = tensors[0].shape(1);
-    int W = tensors[0].shape(2);
-    int C = tensors[0].shape(3);
+  Result<void> apply(const std::vector<Tensor>& inputs, Tensor& output, int clip_len,
+                     int num_clips) override {
+    auto N = static_cast<int64_t>(inputs.size());
+    auto H = inputs[0].shape(1);
+    auto W = inputs[0].shape(2);
+    auto C = inputs[0].shape(3);
 
     auto t0 = std::chrono::high_resolution_clock::now();
     TensorDesc desc = {device_, DataType::kFLOAT, {N, H, W, C}};
@@ -31,39 +45,39 @@ class FormatShapeImpl : public ::mmdeploy::FormatShapeImpl {
     int n_item = H * W * C;
     int copy_size = n_item * sizeof(float);
     for (int i = 0; i < N; i++) {
-      auto src_buffer = tensors[i].buffer();
+      auto src_buffer = inputs[i].buffer();
       auto dst_buffer = imgs.buffer();
-      OUTCOME_TRY(stream_.Copy(src_buffer, dst_buffer, copy_size, 0, offset));
+      OUTCOME_TRY(stream().Copy(src_buffer, dst_buffer, copy_size, 0, offset));
       offset += copy_size;
     }
 
-    Tensor dst;
-    if (arg_.input_format == "NCHW") {
-      OUTCOME_TRY(dst, FormatNCHW(imgs, clip_len, num_clips));
+    // Tensor dst;
+    if (input_format_ == "NCHW") {
+      OUTCOME_TRY(output, FormatNCHW(imgs, clip_len, num_clips));
     }
-    if (arg_.input_format == "NCTHW") {
-      OUTCOME_TRY(dst, FormatNCTHW(imgs, clip_len, num_clips));
+    if (input_format_ == "NCTHW") {
+      OUTCOME_TRY(output, FormatNCTHW(imgs, clip_len, num_clips));
     }
-    TensorShape expand_dim = dst.shape();
+    TensorShape expand_dim = output.shape();
     expand_dim.insert(expand_dim.begin(), 1);
-    dst.Reshape(expand_dim);
+    output.Reshape(expand_dim);
 
-    return dst;
+    return success();
   }
 
   Result<Tensor> FormatNCHW(Tensor& src, int clip_len, int num_clips) {
-    int N = src.shape(0);
-    int H = src.shape(1);
-    int W = src.shape(2);
-    int C = src.shape(3);
+    auto N = src.shape(0);
+    auto H = src.shape(1);
+    auto W = src.shape(2);
+    auto C = src.shape(3);
     return Transpose(src, {N, H, W, C}, {0, 3, 1, 2});
   };
 
   Result<Tensor> FormatNCTHW(Tensor& src, int clip_len, int num_clips) {
-    int N = src.shape(0);
-    int H = src.shape(1);
-    int W = src.shape(2);
-    int C = src.shape(3);
+    auto N = src.shape(0);
+    auto H = src.shape(1);
+    auto W = src.shape(2);
+    auto C = src.shape(3);
     int L = clip_len;
     if (N % L != 0) {
       return Status(eInvalidArgument);
@@ -74,7 +88,7 @@ class FormatShapeImpl : public ::mmdeploy::FormatShapeImpl {
     return Transpose(src, {M, L, H, W, C}, {0, 4, 1, 2, 3});
   };
 
-  Result<Tensor> Transpose(Tensor& src, const std::vector<int>& src_dims,
+  Result<Tensor> Transpose(Tensor& src, const TensorShape& src_dims,
                            const std::vector<int>& permutation) {
     Tensor dst(src.desc());
     TensorShape shape(src.shape().size());
@@ -83,7 +97,15 @@ class FormatShapeImpl : public ::mmdeploy::FormatShapeImpl {
     }
     dst.Reshape(shape);
 
-    int ndim = src_dims.size();
+    SetCudnnTensorDescriptor(src_dims, permutation);
+    CUDNN_CHECK(cudnnTransformTensor(handle_, &one_, src_desc_, src.data<float>(), &zero_,
+                                     dst_desc_, dst.data<float>()));
+
+    return dst;
+  }
+
+  void SetCudnnTensorDescriptor(const TensorShape& src_dims, const std::vector<int>& permutation) {
+    auto ndim = src_dims.size();
     std::vector<int> dst_dims(ndim);
     for (int i = 0; i < ndim; i++) {
       dst_dims[i] = src_dims[permutation[i]];
@@ -102,19 +124,21 @@ class FormatShapeImpl : public ::mmdeploy::FormatShapeImpl {
       src_strides[i] = buffer[permutation[i]];
     }
 
-    Buffer _src_strides(Device("cuda"), sizeof(int) * ndim);
-    Buffer _dst_strides(Device("cuda"), sizeof(int) * ndim);
-    OUTCOME_TRY(stream_.Copy(src_strides.data(), _src_strides));
-    OUTCOME_TRY(stream_.Copy(dst_strides.data(), _dst_strides));
-
-    ::mmdeploy::cuda::Transpose(src.data<float>(), GetNative<int*>(_src_strides), dst.data<float>(),
-                                GetNative<int*>(_dst_strides), ndim, src.size(),
-                                (cudaStream_t)stream_.GetNative());
-    return dst;
+    CUDNN_CHECK(cudnnSetTensorNdDescriptor(src_desc_, CUDNN_DATA_FLOAT, ndim, dst_dims.data(),
+                                           src_strides.data()));
+    CUDNN_CHECK(cudnnSetTensorNdDescriptor(dst_desc_, CUDNN_DATA_FLOAT, ndim, dst_dims.data(),
+                                           dst_strides.data()));
   }
+
+  constexpr static float one_{1.0};
+  constexpr static float zero_{0.0};
+  cudnnHandle_t handle_{};
+  cudnnTensorDescriptor_t src_desc_{};
+  cudnnTensorDescriptor_t dst_desc_{};
 };
 
-MMDEPLOY_REGISTER_TRANSFORM_IMPL(::mmdeploy::FormatShapeImpl, (cuda, 0), FormatShapeImpl);
+MMDEPLOY_REGISTER_FACTORY_FUNC(FormatShapeOp, (cuda, 0), [](std::string input_format) {
+  return std::make_unique<FormatShapeImpl>(std::move(input_format));
+});
 
-}  // namespace cuda
-}  // namespace mmdeploy
+}  // namespace mmdeploy::mmaction::cuda
```
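
The CUDA rewrite drops the hand-written `Transpose` kernel in favor of cuDNN's generic tensor transform. The trick: both descriptors list the dimensions in destination (permuted) order, the destination gets packed strides, the source gets its own strides permuted into that order, and `cudnnTransformTensor` (which computes `y = alpha*x + beta*y` honoring both stride layouts, so with alpha=1 and beta=0 it is a pure strided copy) materializes the permutation. Below is a standalone sketch of that stride trick for an NHWC to NCHW transpose; sizes and variable names are illustrative, not from the commit:

```cpp
// Standalone sketch of the cuDNN stride trick used above: describe the same
// buffer with permuted strides and cudnnTransformTensor becomes a transpose.
#include <cuda_runtime.h>
#include <cudnn.h>

#include <cstdio>
#include <vector>

#define CHECK_CUDNN(call)                                        \
  do {                                                           \
    cudnnStatus_t s = (call);                                    \
    if (s != CUDNN_STATUS_SUCCESS) {                             \
      std::printf("cudnn error: %s\n", cudnnGetErrorString(s));  \
      return 1;                                                  \
    }                                                            \
  } while (0)

int main() {
  // Source: NHWC tensor of shape {N, H, W, C}; destination: NCHW.
  const int N = 2, H = 3, W = 4, C = 5, total = N * H * W * C;
  // Both descriptors list dims in *destination* (NCHW) order...
  int dims[4] = {N, C, H, W};
  // ...the destination gets packed NCHW strides...
  int dst_strides[4] = {C * H * W, H * W, W, 1};
  // ...and the source gets its NHWC strides permuted into NCHW order:
  // one step along N, C, H, W of the output moves this far in the input.
  int src_strides[4] = {H * W * C, 1, W * C, C};

  std::vector<float> h_src(total), h_dst(total);
  for (int i = 0; i < total; i++) h_src[i] = float(i);

  float *d_src = nullptr, *d_dst = nullptr;
  cudaMalloc(&d_src, total * sizeof(float));
  cudaMalloc(&d_dst, total * sizeof(float));
  cudaMemcpy(d_src, h_src.data(), total * sizeof(float), cudaMemcpyHostToDevice);

  cudnnHandle_t handle;
  cudnnTensorDescriptor_t src_desc, dst_desc;
  CHECK_CUDNN(cudnnCreate(&handle));
  CHECK_CUDNN(cudnnCreateTensorDescriptor(&src_desc));
  CHECK_CUDNN(cudnnCreateTensorDescriptor(&dst_desc));
  CHECK_CUDNN(cudnnSetTensorNdDescriptor(src_desc, CUDNN_DATA_FLOAT, 4, dims, src_strides));
  CHECK_CUDNN(cudnnSetTensorNdDescriptor(dst_desc, CUDNN_DATA_FLOAT, 4, dims, dst_strides));

  // y = alpha * permute(x) + beta * y; alpha=1, beta=0 gives a strided copy.
  const float one = 1.0f, zero = 0.0f;
  CHECK_CUDNN(cudnnTransformTensor(handle, &one, src_desc, d_src, &zero, dst_desc, d_dst));
  cudaMemcpy(h_dst.data(), d_dst, total * sizeof(float), cudaMemcpyDeviceToHost);

  // Verify: dst[n][c][h][w] == src[n][h][w][c].
  bool ok = true;
  for (int n = 0; n < N; n++)
    for (int c = 0; c < C; c++)
      for (int h = 0; h < H; h++)
        for (int w = 0; w < W; w++)
          ok &= h_dst[((n * C + c) * H + h) * W + w] == h_src[((n * H + h) * W + w) * C + c];
  std::printf("transpose %s\n", ok ? "ok" : "mismatch");

  cudnnDestroyTensorDescriptor(src_desc);
  cudnnDestroyTensorDescriptor(dst_desc);
  cudnnDestroy(handle);
  cudaFree(d_src);
  cudaFree(d_dst);
  return ok ? 0 : 1;
}
```

Trading the custom kernel for cuDNN removes the per-call stride uploads to device memory that the old code performed (the two `Buffer` copies), at the cost of holding a cuDNN handle and two descriptors per op instance.
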
