Skip to content

Commit 7e66cfc

Browse files
authored
Cherry-pick PR #1460: optimize pose tracker
Cherry-pick PR 1460 to dev-1.x
2 parents ef260d8 + f7ea130 commit 7e66cfc

File tree

15 files changed

+1055
-232
lines changed

15 files changed

+1055
-232
lines changed

csrc/mmdeploy/codebase/mmpose/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ project(mmdeploy_mmpose)
66
file(GLOB_RECURSE SRCS ${CMAKE_CURRENT_SOURCE_DIR} "*.cpp")
77
mmdeploy_add_module(${PROJECT_NAME} "${SRCS}")
88
target_link_libraries(${PROJECT_NAME} PRIVATE
9-
mmdeploy::transform mmdeploy_opencv_utils)
9+
mmdeploy::transform
10+
mmdeploy_operation
11+
mmdeploy_opencv_utils)
1012
add_library(mmdeploy::mmpose ALIAS ${PROJECT_NAME})
1113

1214
set(MMDEPLOY_TASKS ${MMDEPLOY_TASKS} pose_detector CACHE INTERNAL "")

csrc/mmdeploy/codebase/mmpose/topdown_affine.cpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
#include "mmdeploy/core/tensor.h"
88
#include "mmdeploy/core/utils/device_utils.h"
99
#include "mmdeploy/core/utils/formatter.h"
10+
#include "mmdeploy/operation/managed.h"
11+
#include "mmdeploy/operation/vision.h"
1012
#include "mmdeploy/preprocess/transform/transform.h"
1113
#include "opencv2/imgproc.hpp"
1214
#include "opencv_utils.h"
@@ -32,18 +34,15 @@ class TopDownAffine : public transform::Transform {
3234
stream_ = args["context"]["stream"].get<Stream>();
3335
assert(args.contains("image_size"));
3436
from_value(args["image_size"], image_size_);
37+
warp_affine_ = operation::Managed<operation::WarpAffine>::Create("bilinear");
3538
}
3639

3740
~TopDownAffine() override = default;
3841

3942
Result<void> Apply(Value& data) override {
4043
MMDEPLOY_DEBUG("top_down_affine input: {}", data);
4144

42-
Device host{"cpu"};
43-
auto _img = data["img"].get<Tensor>();
44-
OUTCOME_TRY(auto img, MakeAvailableOnDevice(_img, host, stream_));
45-
stream_.Wait().value();
46-
auto src = cpu::Tensor2CVMat(img);
45+
auto img = data["img"].get<Tensor>();
4746

4847
// prepare data
4948
vector<float> bbox;
@@ -62,21 +61,20 @@ class TopDownAffine : public transform::Transform {
6261

6362
auto r = data["rotation"].get<float>();
6463

65-
cv::Mat dst;
64+
Tensor dst;
6665
if (use_udp_) {
6766
cv::Mat trans =
6867
GetWarpMatrix(r, {c[0] * 2.f, c[1] * 2.f}, {image_size_[0] - 1.f, image_size_[1] - 1.f},
6968
{s[0] * 200.f, s[1] * 200.f});
70-
71-
cv::warpAffine(src, dst, trans, {image_size_[0], image_size_[1]}, cv::INTER_LINEAR);
69+
OUTCOME_TRY(warp_affine_.Apply(img, dst, trans.ptr<float>(), image_size_[1], image_size_[0]));
7270
} else {
7371
cv::Mat trans =
7472
GetAffineTransform({c[0], c[1]}, {s[0], s[1]}, r, {image_size_[0], image_size_[1]});
75-
cv::warpAffine(src, dst, trans, {image_size_[0], image_size_[1]}, cv::INTER_LINEAR);
73+
OUTCOME_TRY(warp_affine_.Apply(img, dst, trans.ptr<float>(), image_size_[1], image_size_[0]));
7674
}
7775

78-
data["img"] = cpu::CVMat2Tensor(dst);
79-
data["img_shape"] = {1, image_size_[1], image_size_[0], dst.channels()};
76+
data["img_shape"] = {1, image_size_[1], image_size_[0], dst.shape(3)};
77+
data["img"] = std::move(dst);
8078
data["center"] = to_value(c);
8179
data["scale"] = to_value(s);
8280
MMDEPLOY_DEBUG("output: {}", data);
@@ -106,7 +104,7 @@ class TopDownAffine : public transform::Transform {
106104
theta = theta * 3.1415926 / 180;
107105
float scale_x = size_dst.width / size_target.width;
108106
float scale_y = size_dst.height / size_target.height;
109-
cv::Mat matrix = cv::Mat(2, 3, CV_32FC1);
107+
cv::Mat matrix = cv::Mat(2, 3, CV_32F);
110108
matrix.at<float>(0, 0) = std::cos(theta) * scale_x;
111109
matrix.at<float>(0, 1) = -std::sin(theta) * scale_x;
112110
matrix.at<float>(0, 2) =
@@ -142,6 +140,7 @@ class TopDownAffine : public transform::Transform {
142140

143141
cv::Mat trans = inv ? cv::getAffineTransform(dst_points, src_points)
144142
: cv::getAffineTransform(src_points, dst_points);
143+
trans.convertTo(trans, CV_32F);
145144
return trans;
146145
}
147146

@@ -160,6 +159,7 @@ class TopDownAffine : public transform::Transform {
160159
}
161160

162161
protected:
162+
operation::Managed<operation::WarpAffine> warp_affine_;
163163
bool use_udp_{false};
164164
vector<int> image_size_;
165165
std::string backend_;

csrc/mmdeploy/operation/cpu/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ set(SRCS resize.cpp
99
hwc2chw.cpp
1010
normalize.cpp
1111
crop.cpp
12-
flip.cpp)
12+
flip.cpp
13+
warp_affine.cpp)
1314

1415
mmdeploy_add_module(${PROJECT_NAME} "${SRCS}")
1516

csrc/mmdeploy/operation/cpu/resize.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ namespace mmdeploy::operation::cpu {
77

88
class ResizeImpl : public Resize {
99
public:
10-
ResizeImpl(std::string interp) : interp_(std::move(interp)) {}
10+
explicit ResizeImpl(std::string interp) : interp_(std::move(interp)) {}
1111

1212
Result<void> apply(const Tensor& src, Tensor& dst, int dst_h, int dst_w) override {
1313
auto src_mat = mmdeploy::cpu::Tensor2CVMat(src);
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// Copyright (c) OpenMMLab. All rights reserved.
2+
3+
#include "mmdeploy/operation/vision.h"
4+
#include "mmdeploy/utils/opencv/opencv_utils.h"
5+
6+
namespace mmdeploy::operation::cpu {
7+
8+
class WarpAffineImpl : public WarpAffine {
9+
public:
10+
explicit WarpAffineImpl(int method) : method_(method) {}
11+
12+
Result<void> apply(const Tensor& src, Tensor& dst, const float affine_matrix[6], int dst_h,
13+
int dst_w) override {
14+
auto src_mat = mmdeploy::cpu::Tensor2CVMat(src);
15+
cv::Mat_<float> _matrix(2, 3, const_cast<float*>(affine_matrix));
16+
auto dst_mat = mmdeploy::cpu::WarpAffine(src_mat, _matrix, dst_h, dst_w, method_);
17+
dst = mmdeploy::cpu::CVMat2Tensor(dst_mat);
18+
return success();
19+
}
20+
21+
private:
22+
int method_;
23+
};
24+
25+
MMDEPLOY_REGISTER_FACTORY_FUNC(WarpAffine, (cpu, 0), [](const string_view& interp) {
26+
return std::make_unique<WarpAffineImpl>(::mmdeploy::cpu::GetInterpolationMethod(interp).value());
27+
});
28+
29+
} // namespace mmdeploy::operation::cpu

csrc/mmdeploy/operation/cuda/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ set(SRCS resize.cpp
1717
normalize.cu
1818
crop.cpp
1919
crop.cu
20-
flip.cpp)
20+
flip.cpp
21+
warp_affine.cpp)
2122

2223
mmdeploy_add_module(${PROJECT_NAME} "${SRCS}")
2324

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
// Copyright (c) OpenMMLab. All rights reserved.
2+
3+
#include "mmdeploy/core/utils/formatter.h"
4+
#include "mmdeploy/operation/vision.h"
5+
#include "ppl/cv/cuda/warpaffine.h"
6+
7+
namespace mmdeploy::operation::cuda {
8+
9+
class WarpAffineImpl : public WarpAffine {
10+
public:
11+
explicit WarpAffineImpl(ppl::cv::InterpolationType interp) : interp_(interp) {}
12+
13+
Result<void> apply(const Tensor& src, Tensor& dst, const float affine_matrix[6], int dst_h,
14+
int dst_w) override {
15+
assert(src.device() == device());
16+
17+
TensorDesc desc{device(), src.data_type(), {1, dst_h, dst_w, src.shape(3)}, src.name()};
18+
Tensor dst_tensor(desc);
19+
20+
const auto m = affine_matrix;
21+
auto inv = Invert(affine_matrix);
22+
23+
auto cuda_stream = GetNative<cudaStream_t>(stream());
24+
if (src.data_type() == DataType::kINT8) {
25+
OUTCOME_TRY(Dispatch<uint8_t>(src, dst_tensor, inv.data(), cuda_stream));
26+
} else if (src.data_type() == DataType::kFLOAT) {
27+
OUTCOME_TRY(Dispatch<float>(src, dst_tensor, inv.data(), cuda_stream));
28+
} else {
29+
MMDEPLOY_ERROR("unsupported data type {}", src.data_type());
30+
return Status(eNotSupported);
31+
}
32+
33+
dst = std::move(dst_tensor);
34+
return success();
35+
}
36+
37+
private:
38+
// ppl.cv uses inverted transform
39+
// https://github.com/opencv/opencv/blob/bc6544c0bcfa9ca5db5e0d0551edf5c8e7da3852/modules/imgproc/src/imgwarp.cpp#L3478
40+
static std::array<float, 6> Invert(const float affine_matrix[6]) {
41+
const auto* M = affine_matrix;
42+
std::array<float, 6> inv{};
43+
auto iM = inv.data();
44+
45+
auto D = M[0] * M[3 + 1] - M[1] * M[3];
46+
D = D != 0.f ? 1.f / D : 0.f;
47+
auto A11 = M[3 + 1] * D, A22 = M[0] * D, A12 = -M[1] * D, A21 = -M[3] * D;
48+
auto b1 = -A11 * M[2] - A12 * M[3 + 2];
49+
auto b2 = -A21 * M[2] - A22 * M[3 + 2];
50+
51+
iM[0] = A11;
52+
iM[1] = A12;
53+
iM[2] = b1;
54+
iM[3] = A21;
55+
iM[3 + 1] = A22;
56+
iM[3 + 2] = b2;
57+
58+
return inv;
59+
}
60+
61+
template <typename T>
62+
auto Select(int channels) -> decltype(&ppl::cv::cuda::WarpAffine<T, 1>) {
63+
switch (channels) {
64+
case 1:
65+
return &ppl::cv::cuda::WarpAffine<T, 1>;
66+
case 3:
67+
return &ppl::cv::cuda::WarpAffine<T, 3>;
68+
case 4:
69+
return &ppl::cv::cuda::WarpAffine<T, 4>;
70+
default:
71+
MMDEPLOY_ERROR("unsupported channels {}", channels);
72+
return nullptr;
73+
}
74+
}
75+
76+
template <class T>
77+
Result<void> Dispatch(const Tensor& src, Tensor& dst, const float affine_matrix[6],
78+
cudaStream_t stream) {
79+
int h = (int)src.shape(1);
80+
int w = (int)src.shape(2);
81+
int c = (int)src.shape(3);
82+
int dst_h = (int)dst.shape(1);
83+
int dst_w = (int)dst.shape(2);
84+
85+
auto input = src.data<T>();
86+
auto output = dst.data<T>();
87+
88+
ppl::common::RetCode ret = 0;
89+
90+
if (auto warp_affine = Select<T>(c); warp_affine) {
91+
ret = warp_affine(stream, h, w, w * c, input, dst_h, dst_w, dst_w * c, output, affine_matrix,
92+
interp_, ppl::cv::BORDER_CONSTANT, 0);
93+
} else {
94+
return Status(eNotSupported);
95+
}
96+
97+
return ret == 0 ? success() : Result<void>(Status(eFail));
98+
}
99+
100+
ppl::cv::InterpolationType interp_;
101+
};
102+
103+
static auto Create(const string_view& interp) {
104+
ppl::cv::InterpolationType type{};
105+
if (interp == "bilinear") {
106+
type = ppl::cv::InterpolationType::INTERPOLATION_LINEAR;
107+
} else if (interp == "nearest") {
108+
type = ppl::cv::InterpolationType::INTERPOLATION_NEAREST_POINT;
109+
} else {
110+
MMDEPLOY_ERROR("unsupported interpolation method: {}", interp);
111+
throw_exception(eNotSupported);
112+
}
113+
return std::make_unique<WarpAffineImpl>(type);
114+
}
115+
116+
MMDEPLOY_REGISTER_FACTORY_FUNC(WarpAffine, (cuda, 0), Create);
117+
118+
} // namespace mmdeploy::operation::cuda

csrc/mmdeploy/operation/vision.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,6 @@ MMDEPLOY_DEFINE_REGISTRY(HWC2CHW);
1212
MMDEPLOY_DEFINE_REGISTRY(Normalize);
1313
MMDEPLOY_DEFINE_REGISTRY(Crop);
1414
MMDEPLOY_DEFINE_REGISTRY(Flip);
15+
MMDEPLOY_DEFINE_REGISTRY(WarpAffine);
1516

1617
} // namespace mmdeploy::operation

csrc/mmdeploy/operation/vision.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,13 @@ class Flip : public Operation {
7676
};
7777
MMDEPLOY_DECLARE_REGISTRY(Flip, unique_ptr<Flip>(int flip_code));
7878

79-
// TODO: warp affine
79+
// 2x3 OpenCV affine matrix, row major
80+
class WarpAffine : public Operation {
81+
public:
82+
virtual Result<void> apply(const Tensor& src, Tensor& dst, const float affine_matrix[6],
83+
int dst_h, int dst_w) = 0;
84+
};
85+
MMDEPLOY_DECLARE_REGISTRY(WarpAffine, unique_ptr<WarpAffine>(const string_view& interp));
8086

8187
} // namespace mmdeploy::operation
8288

csrc/mmdeploy/preprocess/transform/load.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,12 @@ class PrepareImage : public Transform {
4848

4949
Result<void> Apply(Value& data) override {
5050
MMDEPLOY_DEBUG("input: {}", data);
51+
52+
// early exit
53+
if (data.contains("img") && data["img"].is_any<Tensor>()) {
54+
return success();
55+
}
56+
5157
assert(data.contains("ori_img"));
5258

5359
Mat src_mat = data["ori_img"].get<Mat>();

0 commit comments

Comments
 (0)