Skip to content

Commit 3309bc4

Browse files
lzhangzzlvhan028
authored andcommitted
[Enhancement] Optimize C++ demos (#1715)
* optimize demos * show text in image * optimize demos * fix minor * fix minor * fix minor * install utils & fix demo file extensions * rename * parse empty flags * antialias * handle video complications (cherry picked from commit 2b18596)
1 parent 98b0228 commit 3309bc4

20 files changed

+1481
-525
lines changed

csrc/mmdeploy/apis/c/mmdeploy/pose_tracker.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ typedef struct mmdeploy_pose_tracker_param_t {
3636
int32_t pose_max_num_bboxes;
3737
// threshold for visible key-points, default = 0.5
3838
float pose_kpt_thr;
39-
// min number of key-points for valid poses, default = -1
39+
// min number of key-points for valid poses (-1 indicates ceil(n_kpts/2)), default = -1
4040
int32_t pose_min_keypoints;
4141
// scale for expanding key-points to bbox, default = 1.25
4242
float pose_bbox_scale;

csrc/mmdeploy/apis/cxx/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,5 @@ install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/mmdeploy/common.hpp
2424
install(DIRECTORY ${CMAKE_SOURCE_DIR}/demo/csrc/ DESTINATION example/cpp
2525
FILES_MATCHING
2626
PATTERN "*.cxx"
27+
PATTERN "*.h"
2728
)

csrc/mmdeploy/codebase/mmpose/pose_tracker/utils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ using Points = vector<cv::Point2f>;
2222
using Score = float;
2323
using Scores = vector<float>;
2424

25-
#define POSE_TRACKER_DEBUG(...) MMDEPLOY_INFO(__VA_ARGS__)
25+
#define POSE_TRACKER_DEBUG(...) MMDEPLOY_DEBUG(__VA_ARGS__)
2626

2727
// opencv3 can't construct cv::Mat from std::array
2828
template <size_t N>

demo/csrc/cpp/classifier.cxx

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,43 @@
11

22
#include "mmdeploy/classifier.hpp"
33

4-
#include <string>
5-
64
#include "opencv2/imgcodecs/imgcodecs.hpp"
5+
#include "utils/argparse.h"
6+
#include "utils/visualize.h"
7+
8+
DEFINE_ARG_string(model, "Model path");
9+
DEFINE_ARG_string(image, "Input image path");
10+
DEFINE_string(device, "cpu", R"(Device name, e.g. "cpu", "cuda")");
11+
DEFINE_string(output, "classifier_output.jpg", "Output image path");
712

813
int main(int argc, char* argv[]) {
9-
if (argc != 4) {
10-
fprintf(stderr, "usage:\n image_classification device_name model_path image_path\n");
11-
return 1;
14+
if (!utils::ParseArguments(argc, argv)) {
15+
return -1;
1216
}
13-
auto device_name = argv[1];
14-
auto model_path = argv[2];
15-
auto image_path = argv[3];
16-
cv::Mat img = cv::imread(image_path);
17-
if (!img.data) {
18-
fprintf(stderr, "failed to load image: %s\n", image_path);
19-
return 1;
17+
18+
cv::Mat img = cv::imread(ARGS_image);
19+
if (img.empty()) {
20+
fprintf(stderr, "failed to load image: %s\n", ARGS_image.c_str());
21+
return -1;
2022
}
2123

22-
mmdeploy::Model model(model_path);
23-
mmdeploy::Classifier classifier(model, mmdeploy::Device{device_name, 0});
24+
// construct a classifier instance
25+
mmdeploy::Classifier classifier(mmdeploy::Model{ARGS_model}, mmdeploy::Device{FLAGS_device});
2426

25-
auto res = classifier.Apply(img);
27+
// apply the classifier; the result is an array-like class holding references to
28+
// `mmdeploy_classification_t`, will be released automatically on destruction
29+
mmdeploy::Classifier::Result result = classifier.Apply(img);
30+
31+
// visualize results
32+
utils::Visualize v;
33+
auto sess = v.get_session(img);
34+
int count = 0;
35+
for (const mmdeploy_classification_t& cls : result) {
36+
sess.add_label(cls.label_id, cls.score, count++);
37+
}
2638

27-
for (const auto& cls : res) {
28-
fprintf(stderr, "label: %d, score: %.4f\n", cls.label_id, cls.score);
39+
if (!FLAGS_output.empty()) {
40+
cv::imwrite(FLAGS_output, sess.get());
2941
}
3042

3143
return 0;

demo/csrc/cpp/det_pose.cpp

Lines changed: 0 additions & 113 deletions
This file was deleted.

demo/csrc/cpp/det_pose.cxx

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
// Copyright (c) OpenMMLab. All rights reserved.
2+
3+
#include <iostream>
4+
5+
#include "mmdeploy/detector.hpp"
6+
#include "mmdeploy/pose_detector.hpp"
7+
#include "opencv2/imgcodecs/imgcodecs.hpp"
8+
#include "utils/argparse.h"
9+
#include "utils/visualize.h"
10+
11+
DEFINE_ARG_string(det_model, "Object detection model path");
12+
DEFINE_ARG_string(pose_model, "Pose estimation model path");
13+
DEFINE_ARG_string(image, "Input image path");
14+
15+
DEFINE_string(device, "cpu", R"(Device name, e.g. "cpu", "cuda")");
16+
DEFINE_string(output, "det_pose_output.jpg", "Output image path");
17+
DEFINE_string(skeleton, "coco", R"(Path to skeleton data or name of predefined skeletons: "coco")");
18+
19+
DEFINE_int32(det_label, 0, "Detection label use for pose estimation");
20+
DEFINE_double(det_thr, .5, "Detection score threshold");
21+
DEFINE_double(det_min_bbox_size, -1, "Detection minimum bbox size");
22+
23+
DEFINE_double(pose_thr, 0, "Pose key-point threshold");
24+
25+
int main(int argc, char* argv[]) {
26+
if (!utils::ParseArguments(argc, argv)) {
27+
return -1;
28+
}
29+
30+
cv::Mat img = cv::imread(ARGS_image);
31+
if (img.empty()) {
32+
fprintf(stderr, "failed to load image: %s\n", ARGS_image.c_str());
33+
return -1;
34+
}
35+
36+
mmdeploy::Device device{FLAGS_device};
37+
// create object detector
38+
mmdeploy::Detector detector(mmdeploy::Model(ARGS_det_model), device);
39+
// create pose detector
40+
mmdeploy::PoseDetector pose(mmdeploy::Model(ARGS_pose_model), device);
41+
42+
// apply the detector, the result is an array-like class holding references to
43+
// `mmdeploy_detection_t`, will be released automatically on destruction
44+
mmdeploy::Detector::Result dets = detector.Apply(img);
45+
46+
// filter detections and extract bboxes for pose model
47+
std::vector<mmdeploy_rect_t> bboxes;
48+
for (const mmdeploy_detection_t& det : dets) {
49+
if (det.label_id == FLAGS_det_label && det.score > FLAGS_det_thr) {
50+
bboxes.push_back(det.bbox);
51+
}
52+
}
53+
54+
// apply pose detector, if no bboxes are provided, full image will be used; the result is an
55+
// array-like class holding references to `mmdeploy_pose_detection_t`, will be released
56+
// automatically on destruction
57+
mmdeploy::PoseDetector::Result poses = pose.Apply(img, bboxes);
58+
59+
assert(bboxes.size() == poses.size());
60+
61+
// visualize results
62+
utils::Visualize v;
63+
v.set_skeleton(utils::Skeleton::get(FLAGS_skeleton));
64+
auto sess = v.get_session(img);
65+
for (size_t i = 0; i < bboxes.size(); ++i) {
66+
sess.add_bbox(bboxes[i], -1, -1);
67+
sess.add_pose(poses[i].point, poses[i].score, poses[i].length, FLAGS_pose_thr);
68+
}
69+
70+
if (!FLAGS_output.empty()) {
71+
cv::imwrite(FLAGS_output, sess.get());
72+
}
73+
74+
return 0;
75+
}

demo/csrc/cpp/detector.cxx

Lines changed: 32 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,69 +1,47 @@
11
#include "mmdeploy/detector.hpp"
22

3-
#include <opencv2/imgcodecs/imgcodecs.hpp>
4-
#include <opencv2/imgproc/imgproc.hpp>
5-
#include <string>
3+
#include "opencv2/imgcodecs/imgcodecs.hpp"
4+
#include "utils/argparse.h"
5+
#include "utils/visualize.h"
66

7-
int main(int argc, char* argv[]) {
8-
if (argc != 4) {
9-
fprintf(stderr, "usage:\n object_detection device_name model_path image_path\n");
10-
return 1;
11-
}
12-
auto device_name = argv[1];
13-
auto model_path = argv[2];
14-
auto image_path = argv[3];
15-
cv::Mat img = cv::imread(image_path);
16-
if (!img.data) {
17-
fprintf(stderr, "failed to load image: %s\n", image_path);
18-
return 1;
19-
}
20-
21-
mmdeploy::Model model(model_path);
22-
mmdeploy::Detector detector(model, mmdeploy::Device{device_name, 0});
23-
24-
auto dets = detector.Apply(img);
25-
26-
fprintf(stdout, "bbox_count=%d\n", (int)dets.size());
27-
28-
for (int i = 0; i < dets.size(); ++i) {
29-
const auto& box = dets[i].bbox;
30-
const auto& mask = dets[i].mask;
7+
DEFINE_ARG_string(model, "Model path");
8+
DEFINE_ARG_string(image, "Input image path");
9+
DEFINE_string(device, "cpu", R"(Device name, e.g. "cpu", "cuda")");
10+
DEFINE_string(output, "detector_output.jpg", "Output image path");
3111

32-
fprintf(stdout, "box %d, left=%.2f, top=%.2f, right=%.2f, bottom=%.2f, label=%d, score=%.4f\n",
33-
i, box.left, box.top, box.right, box.bottom, dets[i].label_id, dets[i].score);
12+
DEFINE_double(det_thr, .5, "Detection score threshold");
3413

35-
// skip detections with invalid bbox size (bbox height or width < 1)
36-
if ((box.right - box.left) < 1 || (box.bottom - box.top) < 1) {
37-
continue;
38-
}
14+
int main(int argc, char* argv[]) {
15+
if (!utils::ParseArguments(argc, argv)) {
16+
return -1;
17+
}
3918

40-
// skip detections less than specified score threshold
41-
if (dets[i].score < 0.3) {
42-
continue;
43-
}
19+
cv::Mat img = cv::imread(ARGS_image);
20+
if (img.empty()) {
21+
fprintf(stderr, "failed to load image: %s\n", ARGS_image.c_str());
22+
return -1;
23+
}
4424

45-
// generate mask overlay if model exports masks
46-
if (mask != nullptr) {
47-
fprintf(stdout, "mask %d, height=%d, width=%d\n", i, mask->height, mask->width);
25+
// construct a detector instance
26+
mmdeploy::Detector detector(mmdeploy::Model{ARGS_model}, mmdeploy::Device{FLAGS_device});
4827

49-
cv::Mat imgMask(mask->height, mask->width, CV_8UC1, &mask->data[0]);
50-
auto x0 = std::max(std::floor(box.left) - 1, 0.f);
51-
auto y0 = std::max(std::floor(box.top) - 1, 0.f);
52-
cv::Rect roi((int)x0, (int)y0, mask->width, mask->height);
28+
// apply the detector, the result is an array-like class holding references to
29+
// `mmdeploy_detection_t`, will be released automatically on destruction
30+
mmdeploy::Detector::Result dets = detector.Apply(img);
5331

54-
// split the RGB channels, overlay mask to a specific color channel
55-
cv::Mat ch[3];
56-
split(img, ch);
57-
int col = 0; // int col = i % 3;
58-
cv::bitwise_or(imgMask, ch[col](roi), ch[col](roi));
59-
merge(ch, 3, img);
32+
// visualize
33+
utils::Visualize v;
34+
auto sess = v.get_session(img);
35+
int count = 0;
36+
for (const mmdeploy_detection_t& det : dets) {
37+
if (det.score > FLAGS_det_thr) { // filter bboxes
38+
sess.add_det(det.bbox, det.label_id, det.score, det.mask, count++);
6039
}
61-
62-
cv::rectangle(img, cv::Point{(int)box.left, (int)box.top},
63-
cv::Point{(int)box.right, (int)box.bottom}, cv::Scalar{0, 255, 0});
6440
}
6541

66-
cv::imwrite("output_detection.png", img);
42+
if (!FLAGS_output.empty()) {
43+
cv::imwrite(FLAGS_output, sess.get());
44+
}
6745

6846
return 0;
6947
}

0 commit comments

Comments
 (0)