Skip to content

Commit 67043c5

Browse files
lzhangzztriple-Mu
authored andcommitted
[Fix] Fix det_pose demo (open-mmlab#1419)
* fix det_pose demo * remove useless input
1 parent fa91f67 commit 67043c5

File tree

1 file changed

+37
-20
lines changed

1 file changed

+37
-20
lines changed

demo/csrc/c/det_pose.cpp

Lines changed: 37 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ const auto config_json = R"(
2323
"type": "Inference",
2424
"input": "img",
2525
"output": "dets",
26-
"params": { "model": "../_detection_tmp_model" }
26+
"params": { "model": "TBD" }
2727
},
2828
{
2929
"type": "Task",
@@ -45,7 +45,7 @@ const auto config_json = R"(
4545
"type": "Inference",
4646
"input": "imgs_with_bboxes",
4747
"output": "keypoints",
48-
"params": { "model": "../posedet_tmp_model" }
48+
"params": { "model": "TBD" }
4949
}
5050
],
5151
"output": "*keypoints"
@@ -73,28 +73,36 @@ class AddBboxField {
7373
MMDEPLOY_REGISTER_FACTORY_FUNC(Module, (AddBboxField, 0),
7474
[](const Value&) { return CreateTask(AddBboxField{}); });
7575

76-
class FilterBbox {
77-
public:
78-
Result<Value> operator()(const Value& dets) {
79-
Value::Array rets;
80-
for (const auto& det : dets) {
81-
if (det["label_id"].get<int>() == 0 && det["score"].get<float>() >= 0.3) {
82-
rets.push_back(det);
83-
}
76+
Result<Value> FilterBbox(const Value& dets) {
77+
Value::Array rets;
78+
for (const auto& det : dets) {
79+
if (det["label_id"].get<int>() == 0 && det["score"].get<float>() >= 0.3) {
80+
rets.push_back(det);
8481
}
85-
return rets;
8682
}
87-
};
83+
return rets;
84+
}
8885

8986
MMDEPLOY_REGISTER_FACTORY_FUNC(Module, (FilterBbox, 0),
90-
[](const Value&) { return CreateTask(FilterBbox{}); });
87+
[](const Value&) { return CreateTask(FilterBbox); });
9188

9289
static std::vector<std::pair<int, int>> skeleton{
9390
{15, 13}, {13, 11}, {16, 14}, {14, 12}, {11, 12}, {5, 11}, {6, 12}, {5, 6}, {5, 7}, {6, 8},
9491
{7, 9}, {8, 10}, {1, 2}, {0, 1}, {0, 2}, {1, 3}, {2, 4}, {3, 5}, {4, 6}};
9592

96-
int main() {
93+
int main(int argc, char* argv[]) {
94+
if (argc != 5) {
95+
MMDEPLOY_INFO("usage: det_pose device det_model pose_model image");
96+
return 0;
97+
}
98+
const auto device_name = argv[1];
99+
const auto det_model_path = argv[2];
100+
const auto pose_model_path = argv[3];
101+
const auto image_path = argv[4];
102+
97103
auto config = from_json<Value>(config_json);
104+
config["tasks"][0]["params"]["model"] = det_model_path;
105+
config["tasks"][2]["tasks"][1]["params"]["model"] = pose_model_path;
98106

99107
mmdeploy_context_t context{};
100108
mmdeploy_context_create(&context);
@@ -105,17 +113,24 @@ int main() {
105113
mmdeploy_context_add(context, MMDEPLOY_TYPE_SCHEDULER, "net", single_thread);
106114
mmdeploy_context_add(context, MMDEPLOY_TYPE_SCHEDULER, "postprocess", thread_pool);
107115

116+
mmdeploy_device_t device{};
117+
mmdeploy_device_create(device_name, 0, &device);
118+
mmdeploy_context_add(context, MMDEPLOY_TYPE_DEVICE, nullptr, device);
119+
108120
mmdeploy_pipeline_t pipeline{};
109121
if (auto ec = mmdeploy_pipeline_create_v3((mmdeploy_value_t)&config, context, &pipeline)) {
110122
MMDEPLOY_ERROR("failed to create pipeline: {}", ec);
111123
return -1;
112124
}
113125

114-
cv::Mat mat = cv::imread("../ezgif-5-6ec14aca55.jpg");
126+
cv::Mat mat = cv::imread(image_path);
127+
if (!mat.data) {
128+
MMDEPLOY_ERROR("invalid image path: {}", image_path);
129+
}
115130
framework::Mat img(mat.rows, mat.cols, PixelFormat::kBGR, DataType::kINT8, mat.data,
116131
framework::Device(0));
117132

118-
Value input = Value::Array{Value::Array{Value::Object{{"ori_img", img}}}};
133+
Value input{{{{"ori_img", img}}}};
119134

120135
mmdeploy_value_t tmp{};
121136
mmdeploy_pipeline_apply(pipeline, (mmdeploy_value_t)&input, &tmp);
@@ -125,16 +140,16 @@ int main() {
125140
mmdeploy_detector_get_result(tmp, &dets, &det_count);
126141

127142
auto output = std::move(*(Value*)tmp);
143+
mmdeploy_value_destroy(tmp);
144+
145+
// result of second output
146+
auto& pose = output[1];
128147

129148
mmdeploy_pose_detection_t* kps{};
130-
Value pose;
131-
pose.push_back(output[1]);
132149
mmdeploy_pose_detector_get_result((mmdeploy_value_t)&pose, &kps);
133150

134151
MMDEPLOY_INFO("{}", *det_count);
135152

136-
mmdeploy_value_destroy(tmp);
137-
138153
for (int i = 0; i < *det_count; ++i) {
139154
if (dets[i].label_id != 0 || dets[i].score < 0.3) {
140155
continue;
@@ -156,6 +171,8 @@ int main() {
156171
}
157172
}
158173

174+
mmdeploy_pose_detector_release_result(kps, pose.size());
175+
159176
cv::imwrite("output_det_pose.jpg", mat);
160177

161178
mmdeploy_pipeline_destroy(pipeline);

0 commit comments

Comments
 (0)