Skip to content

Commit 57f7712

Browse files
irexyclvhan028
authored andcommitted
cherry-pick: Remove cudnn dependency for transform 'mmaction2::format_shape' (open-mmlab#1509)
1 parent 04c9793 commit 57f7712

File tree

7 files changed

+108
-163
lines changed

7 files changed

+108
-163
lines changed

csrc/mmdeploy/codebase/mmaction/cpu/format_shape_impl.cpp

Lines changed: 2 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -12,64 +12,9 @@ class FormatShapeImpl : public FormatShapeOp {
1212
explicit FormatShapeImpl(std::string input_format) : FormatShapeOp(std::move(input_format)) {}
1313

1414
protected:
15-
Result<void> apply(const std::vector<Tensor>& tensors, Tensor& output, int clip_len,
16-
int num_clips) override {
17-
auto N = static_cast<int64_t>(tensors.size());
18-
auto H = tensors[0].shape(1);
19-
auto W = tensors[0].shape(2);
20-
auto C = tensors[0].shape(3);
15+
Device host_{0, 0};
2116

22-
TensorDesc desc = {kHost, DataType::kFLOAT, {N, H, W, C}};
23-
Tensor imgs(desc);
24-
auto offset = 0UL;
25-
auto n_item = H * W * C;
26-
auto copy_size = n_item * sizeof(float);
27-
for (int i = 0; i < N; i++) {
28-
auto src_buffer = tensors[i].buffer();
29-
auto dst_buffer = imgs.buffer();
30-
OUTCOME_TRY(stream().Copy(src_buffer, dst_buffer, copy_size, 0, offset));
31-
offset += copy_size;
32-
}
33-
34-
OUTCOME_TRY(stream().Wait());
35-
36-
Tensor dst;
37-
if (input_format_ == "NCHW") {
38-
OUTCOME_TRY(dst, FormatNCHW(imgs, clip_len, num_clips));
39-
}
40-
if (input_format_ == "NCTHW") {
41-
OUTCOME_TRY(dst, FormatNCTHW(imgs, clip_len, num_clips));
42-
}
43-
TensorShape expand_dim = dst.shape();
44-
expand_dim.insert(expand_dim.begin(), 1);
45-
dst.Reshape(expand_dim);
46-
output = std::move(dst);
47-
48-
return success();
49-
}
50-
51-
Result<Tensor> FormatNCHW(Tensor& src, int clip_len, int num_clips) {
52-
auto N = src.shape(0);
53-
auto H = src.shape(1);
54-
auto W = src.shape(2);
55-
auto C = src.shape(3);
56-
return Transpose(src, {N, H, W, C}, {0, 3, 1, 2});
57-
};
58-
59-
Result<Tensor> FormatNCTHW(Tensor& src, int clip_len, int num_clips) {
60-
auto N = src.shape(0);
61-
auto H = src.shape(1);
62-
auto W = src.shape(2);
63-
auto C = src.shape(3);
64-
auto L = clip_len;
65-
if (N % L != 0) {
66-
return Status(eInvalidArgument);
67-
}
68-
int M = N / L;
69-
src.Reshape({M, L, H, W, C});
70-
71-
return Transpose(src, {M, L, H, W, C}, {0, 4, 1, 2, 3});
72-
};
17+
const Device& GetDevice() { return host_; }
7318

7419
Result<Tensor> Transpose(Tensor& src, const TensorShape& src_dims,
7520
const std::vector<int>& permutation) {
@@ -113,8 +58,6 @@ class FormatShapeImpl : public FormatShapeOp {
11358
} while (i >= 0);
11459
return dst;
11560
}
116-
117-
constexpr static Device kHost{0, 0};
11861
};
11962

12063
MMDEPLOY_REGISTER_FACTORY_FUNC(FormatShapeOp, (cpu, 0), [](std::string input_format) {

csrc/mmdeploy/codebase/mmaction/cuda/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ if (NOT (MMDEPLOY_SHARED_LIBS OR MSVC))
1111
target_compile_options(${PROJECT_NAME} PRIVATE $<$<COMPILE_LANGUAGE:CXX>:-fvisibility=hidden>)
1212
endif ()
1313
target_include_directories(${PROJECT_NAME} PRIVATE
14-
${CUDA_INCLUDE_DIRS})
14+
${CUDA_INCLUDE_DIRS})
1515
target_link_libraries(${PROJECT_NAME} PRIVATE
16-
mmdeploy::core)
16+
mmdeploy::core)
1717
target_link_libraries(mmdeploy_mmaction PRIVATE ${PROJECT_NAME})
1818
mmdeploy_export(${PROJECT_NAME})

csrc/mmdeploy/codebase/mmaction/cuda/format_shape_impl.cpp

Lines changed: 15 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -1,92 +1,23 @@
11
// Copyright (c) OpenMMLab. All rights reserved.
22

3-
#include "cudnn.h"
3+
#include "cuda_runtime.h"
44
#include "mmdeploy/codebase/mmaction/format_shape.h"
55
#include "mmdeploy/core/utils/device_utils.h"
66

77
using namespace std;
88

99
namespace mmdeploy::mmaction::cuda {
1010

11-
#define CUDNN_CHECK(condition) \
12-
do { \
13-
if (condition != CUDNN_STATUS_SUCCESS) { \
14-
MMDEPLOY_ERROR("cudnn error, msg = {}", cudnnGetErrorString(condition)); \
15-
} \
16-
} while (0);
11+
template <typename T>
12+
void Transpose(const T* src, const int* src_strides, T* dst, const int* dst_strides, int ndim,
13+
int total, cudaStream_t stream);
1714

1815
class FormatShapeImpl : public FormatShapeOp {
1916
public:
20-
explicit FormatShapeImpl(std::string input_format) : FormatShapeOp(std::move(input_format)) {
21-
CUDNN_CHECK(cudnnCreate(&handle_));
22-
CUDNN_CHECK(cudnnSetStream(handle_, GetNative<cudaStream_t>(stream())));
23-
CUDNN_CHECK(cudnnCreateTensorDescriptor(&src_desc_));
24-
CUDNN_CHECK(cudnnCreateTensorDescriptor(&dst_desc_));
25-
}
26-
27-
~FormatShapeImpl() override {
28-
CUDNN_CHECK(cudnnDestroy(handle_));
29-
CUDNN_CHECK(cudnnDestroyTensorDescriptor(src_desc_));
30-
CUDNN_CHECK(cudnnDestroyTensorDescriptor(dst_desc_));
31-
}
17+
explicit FormatShapeImpl(std::string input_format) : FormatShapeOp(std::move(input_format)) {}
3218

3319
protected:
34-
Result<void> apply(const std::vector<Tensor>& inputs, Tensor& output, int clip_len,
35-
int num_clips) override {
36-
auto N = static_cast<int64_t>(inputs.size());
37-
auto H = inputs[0].shape(1);
38-
auto W = inputs[0].shape(2);
39-
auto C = inputs[0].shape(3);
40-
41-
auto t0 = std::chrono::high_resolution_clock::now();
42-
TensorDesc desc = {device_, DataType::kFLOAT, {N, H, W, C}};
43-
Tensor imgs(desc);
44-
int offset = 0;
45-
int n_item = H * W * C;
46-
int copy_size = n_item * sizeof(float);
47-
for (int i = 0; i < N; i++) {
48-
auto src_buffer = inputs[i].buffer();
49-
auto dst_buffer = imgs.buffer();
50-
OUTCOME_TRY(stream().Copy(src_buffer, dst_buffer, copy_size, 0, offset));
51-
offset += copy_size;
52-
}
53-
54-
// Tensor dst;
55-
if (input_format_ == "NCHW") {
56-
OUTCOME_TRY(output, FormatNCHW(imgs, clip_len, num_clips));
57-
}
58-
if (input_format_ == "NCTHW") {
59-
OUTCOME_TRY(output, FormatNCTHW(imgs, clip_len, num_clips));
60-
}
61-
TensorShape expand_dim = output.shape();
62-
expand_dim.insert(expand_dim.begin(), 1);
63-
output.Reshape(expand_dim);
64-
65-
return success();
66-
}
67-
68-
Result<Tensor> FormatNCHW(Tensor& src, int clip_len, int num_clips) {
69-
auto N = src.shape(0);
70-
auto H = src.shape(1);
71-
auto W = src.shape(2);
72-
auto C = src.shape(3);
73-
return Transpose(src, {N, H, W, C}, {0, 3, 1, 2});
74-
};
75-
76-
Result<Tensor> FormatNCTHW(Tensor& src, int clip_len, int num_clips) {
77-
auto N = src.shape(0);
78-
auto H = src.shape(1);
79-
auto W = src.shape(2);
80-
auto C = src.shape(3);
81-
int L = clip_len;
82-
if (N % L != 0) {
83-
return Status(eInvalidArgument);
84-
}
85-
int M = N / L;
86-
src.Reshape({M, L, H, W, C});
87-
88-
return Transpose(src, {M, L, H, W, C}, {0, 4, 1, 2, 3});
89-
};
20+
const Device& GetDevice() { return device(); }
9021

9122
Result<Tensor> Transpose(Tensor& src, const TensorShape& src_dims,
9223
const std::vector<int>& permutation) {
@@ -97,14 +28,6 @@ class FormatShapeImpl : public FormatShapeOp {
9728
}
9829
dst.Reshape(shape);
9930

100-
SetCudnnTensorDescriptor(src_dims, permutation);
101-
CUDNN_CHECK(cudnnTransformTensor(handle_, &one_, src_desc_, src.data<float>(), &zero_,
102-
dst_desc_, dst.data<float>()));
103-
104-
return dst;
105-
}
106-
107-
void SetCudnnTensorDescriptor(const TensorShape& src_dims, const std::vector<int>& permutation) {
10831
auto ndim = src_dims.size();
10932
std::vector<int> dst_dims(ndim);
11033
for (int i = 0; i < ndim; i++) {
@@ -124,17 +47,16 @@ class FormatShapeImpl : public FormatShapeOp {
12447
src_strides[i] = buffer[permutation[i]];
12548
}
12649

127-
CUDNN_CHECK(cudnnSetTensorNdDescriptor(src_desc_, CUDNN_DATA_FLOAT, ndim, dst_dims.data(),
128-
src_strides.data()));
129-
CUDNN_CHECK(cudnnSetTensorNdDescriptor(dst_desc_, CUDNN_DATA_FLOAT, ndim, dst_dims.data(),
130-
dst_strides.data()));
131-
}
50+
Buffer _src_strides(Device("cuda"), sizeof(int) * ndim);
51+
Buffer _dst_strides(Device("cuda"), sizeof(int) * ndim);
52+
OUTCOME_TRY(stream().Copy(src_strides.data(), _src_strides));
53+
OUTCOME_TRY(stream().Copy(dst_strides.data(), _dst_strides));
13254

133-
constexpr static float one_{1.0};
134-
constexpr static float zero_{0.0};
135-
cudnnHandle_t handle_{};
136-
cudnnTensorDescriptor_t src_desc_{};
137-
cudnnTensorDescriptor_t dst_desc_{};
55+
::mmdeploy::mmaction::cuda::Transpose(src.data<float>(), GetNative<int*>(_src_strides),
56+
dst.data<float>(), GetNative<int*>(_dst_strides), ndim,
57+
src.size(), (cudaStream_t)stream().GetNative());
58+
return dst;
59+
}
13860
};
13961

14062
MMDEPLOY_REGISTER_FACTORY_FUNC(FormatShapeOp, (cuda, 0), [](std::string input_format) {

csrc/mmdeploy/codebase/mmaction/cuda/transpose.cu

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@
33
#include <stdint.h>
44
#include <stdio.h>
55

6-
namespace mmdeploy {
7-
namespace cuda {
6+
namespace mmdeploy::mmaction::cuda {
87

98
template <typename T>
109
__global__ void transpose(const T* src, const int* src_strides, T* dst, const int* dst_strides,
1110
int ndim, int total) {
1211
int u = blockIdx.x * blockDim.x + threadIdx.x;
13-
if (u >= total) return;
12+
if (u >= total) {
13+
return;
14+
}
1415

1516
int remaining = u;
1617
int v = 0;
@@ -34,5 +35,4 @@ void Transpose(const T* src, const int* src_strides, T* dst, const int* dst_stri
3435
template void Transpose<float>(const float* src, const int* src_strides, float* dst,
3536
const int* dst_strides, int ndim, int total, cudaStream_t stream);
3637

37-
} // namespace cuda
38-
} // namespace mmdeploy
38+
} // namespace mmdeploy::mmaction::cuda

csrc/mmdeploy/codebase/mmaction/format_shape.cpp

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,76 @@ namespace mmdeploy::mmaction {
1212
FormatShape::FormatShape(const Value& args) {
1313
auto input_format = args.value("input_format", std::string(""));
1414
if (input_format != "NCHW" && input_format != "NCTHW") {
15-
throw std::domain_error("'input_format' should be 'NCHW' or 'NCTHW'");
15+
MMDEPLOY_ERROR("'input_format' should be 'NCHW' or 'NCTHW'");
16+
throw_exception(eInvalidArgument);
1617
}
18+
format_ = operation::Managed<mmdeploy::mmaction::FormatShapeOp>::Create(input_format);
19+
}
20+
21+
Result<void> FormatShapeOp::apply(const std::vector<Tensor>& images, Tensor& output, int clip_len,
22+
int num_clips) {
23+
Tensor inputs;
24+
OUTCOME_TRY(MergeInputs(images, inputs));
25+
if (GetDevice().is_host()) {
26+
OUTCOME_TRY(stream().Wait());
27+
}
28+
29+
// Tensor dst;
30+
if (input_format_ == "NCHW") {
31+
OUTCOME_TRY(output, FormatNCHW(inputs, clip_len, num_clips));
32+
}
33+
if (input_format_ == "NCTHW") {
34+
OUTCOME_TRY(output, FormatNCTHW(inputs, clip_len, num_clips));
35+
}
36+
37+
TensorShape expand_dim = output.shape();
38+
expand_dim.insert(expand_dim.begin(), 1);
39+
output.Reshape(expand_dim);
40+
41+
return success();
42+
}
43+
44+
Result<void> FormatShapeOp::MergeInputs(const std::vector<Tensor>& images, Tensor& inputs) {
45+
auto N = static_cast<int64_t>(images.size());
46+
auto H = images[0].shape(1);
47+
auto W = images[0].shape(2);
48+
auto C = images[0].shape(3);
49+
50+
TensorDesc desc = {GetDevice(), DataType::kFLOAT, {N, H, W, C}};
51+
inputs = Tensor(desc);
52+
auto offset = 0UL;
53+
auto n_item = H * W * C;
54+
auto copy_size = n_item * sizeof(float);
55+
for (int i = 0; i < N; i++) {
56+
auto src_buffer = images[i].buffer();
57+
auto dst_buffer = inputs.buffer();
58+
OUTCOME_TRY(stream().Copy(src_buffer, dst_buffer, copy_size, 0, offset));
59+
offset += copy_size;
60+
}
61+
return success();
62+
}
63+
64+
Result<Tensor> FormatShapeOp::FormatNCHW(Tensor& src, int clip_len, int num_clips) {
65+
auto N = src.shape(0);
66+
auto H = src.shape(1);
67+
auto W = src.shape(2);
68+
auto C = src.shape(3);
69+
return Transpose(src, {N, H, W, C}, {0, 3, 1, 2});
70+
}
71+
72+
Result<Tensor> FormatShapeOp::FormatNCTHW(Tensor& src, int clip_len, int num_clips) {
73+
auto N = src.shape(0);
74+
auto H = src.shape(1);
75+
auto W = src.shape(2);
76+
auto C = src.shape(3);
77+
int L = clip_len;
78+
if (N % L != 0) {
79+
return Status(eInvalidArgument);
80+
}
81+
int M = N / L;
82+
src.Reshape({M, L, H, W, C});
83+
84+
return Transpose(src, {M, L, H, W, C}, {0, 4, 1, 2, 3});
1785
}
1886

1987
Result<void> FormatShape::Apply(Value& data) {
@@ -50,6 +118,7 @@ Result<void> FormatShape::Apply(Value& data) {
50118
}
51119

52120
Tensor dst;
121+
data = Value{};
53122
OUTCOME_TRY(format_.Apply(images, dst, clip_len, num_clips));
54123
data["img"] = std::move(dst);
55124

csrc/mmdeploy/codebase/mmaction/format_shape.h

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,19 @@ class FormatShapeOp : public operation::Operation {
1616
public:
1717
explicit FormatShapeOp(std::string input_format) : input_format_(std::move(input_format)){};
1818

19-
virtual Result<void> apply(const std::vector<Tensor>& inputs, Tensor& output, int clip_len,
20-
int num_clips) = 0;
19+
Result<void> apply(const std::vector<Tensor>& inputs, Tensor& output, int clip_len,
20+
int num_clips);
21+
22+
virtual const Device& GetDevice() = 0;
23+
24+
virtual Result<Tensor> Transpose(Tensor& src, const TensorShape& src_dims,
25+
const std::vector<int>& permutation) = 0;
26+
27+
Result<Tensor> FormatNCHW(Tensor& src, int clip_len, int num_clips);
28+
29+
Result<Tensor> FormatNCTHW(Tensor& src, int clip_len, int num_clips);
30+
31+
Result<void> MergeInputs(const std::vector<Tensor>& images, Tensor& inputs);
2132

2233
protected:
2334
std::string input_format_;

csrc/mmdeploy/preprocess/transform/lift.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ namespace mmdeploy::transform {
99
class Lift : public Transform {
1010
public:
1111
explicit Lift(const Value& args) {
12-
const char* type = "compose";
12+
const char* type = "Compose";
1313
if (auto creator = gRegistry<Transform>().Get(type)) {
1414
compose_ = creator->Create(args);
1515
} else {

0 commit comments

Comments
 (0)