|
1 | 1 | #include "mmdeploy/detector.hpp"
|
2 | 2 |
|
3 |
| -#include <opencv2/imgcodecs/imgcodecs.hpp> |
4 |
| -#include <opencv2/imgproc/imgproc.hpp> |
5 |
| -#include <string> |
| 3 | +#include "opencv2/imgcodecs/imgcodecs.hpp" |
| 4 | +#include "utils/argparse.h" |
| 5 | +#include "utils/visualize.h" |
6 | 6 |
|
7 |
| -int main(int argc, char* argv[]) { |
8 |
| - if (argc != 4) { |
9 |
| - fprintf(stderr, "usage:\n object_detection device_name model_path image_path\n"); |
10 |
| - return 1; |
11 |
| - } |
12 |
| - auto device_name = argv[1]; |
13 |
| - auto model_path = argv[2]; |
14 |
| - auto image_path = argv[3]; |
15 |
| - cv::Mat img = cv::imread(image_path); |
16 |
| - if (!img.data) { |
17 |
| - fprintf(stderr, "failed to load image: %s\n", image_path); |
18 |
| - return 1; |
19 |
| - } |
20 |
| - |
21 |
| - mmdeploy::Model model(model_path); |
22 |
| - mmdeploy::Detector detector(model, mmdeploy::Device{device_name, 0}); |
23 |
| - |
24 |
| - auto dets = detector.Apply(img); |
25 |
| - |
26 |
| - fprintf(stdout, "bbox_count=%d\n", (int)dets.size()); |
27 |
| - |
28 |
| - for (int i = 0; i < dets.size(); ++i) { |
29 |
| - const auto& box = dets[i].bbox; |
30 |
| - const auto& mask = dets[i].mask; |
| 7 | +DEFINE_ARG_string(model, "Model path"); |
| 8 | +DEFINE_ARG_string(image, "Input image path"); |
| 9 | +DEFINE_string(device, "cpu", R"(Device name, e.g. "cpu", "cuda")"); |
| 10 | +DEFINE_string(output, "detector_output.jpg", "Output image path"); |
31 | 11 |
|
32 |
| - fprintf(stdout, "box %d, left=%.2f, top=%.2f, right=%.2f, bottom=%.2f, label=%d, score=%.4f\n", |
33 |
| - i, box.left, box.top, box.right, box.bottom, dets[i].label_id, dets[i].score); |
| 12 | +DEFINE_double(det_thr, .5, "Detection score threshold"); |
34 | 13 |
|
35 |
| - // skip detections with invalid bbox size (bbox height or width < 1) |
36 |
| - if ((box.right - box.left) < 1 || (box.bottom - box.top) < 1) { |
37 |
| - continue; |
38 |
| - } |
| 14 | +int main(int argc, char* argv[]) { |
| 15 | + if (!utils::ParseArguments(argc, argv)) { |
| 16 | + return -1; |
| 17 | + } |
39 | 18 |
|
40 |
| - // skip detections less than specified score threshold |
41 |
| - if (dets[i].score < 0.3) { |
42 |
| - continue; |
43 |
| - } |
| 19 | + cv::Mat img = cv::imread(ARGS_image); |
| 20 | + if (img.empty()) { |
| 21 | + fprintf(stderr, "failed to load image: %s\n", ARGS_image.c_str()); |
| 22 | + return -1; |
| 23 | + } |
44 | 24 |
|
45 |
| - // generate mask overlay if model exports masks |
46 |
| - if (mask != nullptr) { |
47 |
| - fprintf(stdout, "mask %d, height=%d, width=%d\n", i, mask->height, mask->width); |
| 25 | + // construct a detector instance |
| 26 | + mmdeploy::Detector detector(mmdeploy::Model{ARGS_model}, mmdeploy::Device{FLAGS_device}); |
48 | 27 |
|
49 |
| - cv::Mat imgMask(mask->height, mask->width, CV_8UC1, &mask->data[0]); |
50 |
| - auto x0 = std::max(std::floor(box.left) - 1, 0.f); |
51 |
| - auto y0 = std::max(std::floor(box.top) - 1, 0.f); |
52 |
| - cv::Rect roi((int)x0, (int)y0, mask->width, mask->height); |
| 28 | + // apply the detector, the result is an array-like class holding references to |
| 29 | + // `mmdeploy_detection_t`, will be released automatically on destruction |
| 30 | + mmdeploy::Detector::Result dets = detector.Apply(img); |
53 | 31 |
|
54 |
| - // split the RGB channels, overlay mask to a specific color channel |
55 |
| - cv::Mat ch[3]; |
56 |
| - split(img, ch); |
57 |
| - int col = 0; // int col = i % 3; |
58 |
| - cv::bitwise_or(imgMask, ch[col](roi), ch[col](roi)); |
59 |
| - merge(ch, 3, img); |
| 32 | + // visualize |
| 33 | + utils::Visualize v; |
| 34 | + auto sess = v.get_session(img); |
| 35 | + int count = 0; |
| 36 | + for (const mmdeploy_detection_t& det : dets) { |
| 37 | + if (det.score > FLAGS_det_thr) { // filter bboxes |
| 38 | + sess.add_det(det.bbox, det.label_id, det.score, det.mask, count++); |
60 | 39 | }
|
61 |
| - |
62 |
| - cv::rectangle(img, cv::Point{(int)box.left, (int)box.top}, |
63 |
| - cv::Point{(int)box.right, (int)box.bottom}, cv::Scalar{0, 255, 0}); |
64 | 40 | }
|
65 | 41 |
|
66 |
| - cv::imwrite("output_detection.png", img); |
| 42 | + if (!FLAGS_output.empty()) { |
| 43 | + cv::imwrite(FLAGS_output, sess.get()); |
| 44 | + } |
67 | 45 |
|
68 | 46 | return 0;
|
69 | 47 | }
|
0 commit comments