From 694637349819272e64d1ef4e59c9d849bffaa8dd Mon Sep 17 00:00:00 2001 From: zhangli Date: Wed, 8 Mar 2023 19:55:11 +0800 Subject: [PATCH 1/2] add softmax in cls postprocess --- csrc/mmdeploy/codebase/mmcls/linear_cls.cpp | 23 ++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/csrc/mmdeploy/codebase/mmcls/linear_cls.cpp b/csrc/mmdeploy/codebase/mmcls/linear_cls.cpp index 2f42edd09b..2ebb1ba019 100644 --- a/csrc/mmdeploy/codebase/mmcls/linear_cls.cpp +++ b/csrc/mmdeploy/codebase/mmcls/linear_cls.cpp @@ -18,6 +18,7 @@ class LinearClsHead : public MMClassification { public: explicit LinearClsHead(const Value& cfg) : MMClassification(cfg) { if (cfg.contains("params")) { + softmax_ = cfg["params"].value("softmax", false); topk_ = cfg["params"].value("topk", 1); if (topk_ <= 0) { MMDEPLOY_ERROR("'topk' should be greater than 0, but got '{}'", topk_); @@ -45,6 +46,7 @@ class LinearClsHead : public MMClassification { } private: + Value GetLabels(const Tensor& scores, int class_num) const { auto scores_data = scores.data(); auto topk = std::min(topk_, class_num); @@ -54,10 +56,24 @@ class LinearClsHead : public MMClassification { iota(begin(idx), end(idx), 0); partial_sort(begin(idx), begin(idx) + topk, end(idx), [&](int i, int j) { return scores_data[i] > scores_data[j]; }); + + auto sum_exp = 0.f; + std::vector exp_scores; + if (softmax_) { + exp_scores.reserve(class_num); + auto max_val = scores_data[idx[0]]; + for (int i = 0; i < class_num; ++i) { + sum_exp += exp_scores.emplace_back(std::exp(scores_data[i] - max_val)); + } + } for (int i = 0; i < topk; ++i) { - auto label = Label{idx[i], scores_data[idx[i]]}; - MMDEPLOY_DEBUG("label_id: {}, score: {}", label.label_id, label.score); - output.push_back(label); + float score = 0.f; + if (softmax_) { + score = exp_scores[idx[i]] / sum_exp; + } else { + score = scores_data[idx[i]]; + } + output.push_back({idx[i], score}); } return to_value(std::move(output)); } @@ -65,6 +81,7 @@ class LinearClsHead : public MMClassification { private: static constexpr const auto kHost = Device{0}; + bool softmax_{false}; int topk_{1}; }; From 116317776911f5361f44d354624757f09b0ff804 Mon Sep 17 00:00:00 2001 From: zhangli Date: Wed, 8 Mar 2023 21:26:41 +0800 Subject: [PATCH 2/2] minor --- csrc/mmdeploy/codebase/mmcls/linear_cls.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/csrc/mmdeploy/codebase/mmcls/linear_cls.cpp b/csrc/mmdeploy/codebase/mmcls/linear_cls.cpp index 2ebb1ba019..3e1f171e3f 100644 --- a/csrc/mmdeploy/codebase/mmcls/linear_cls.cpp +++ b/csrc/mmdeploy/codebase/mmcls/linear_cls.cpp @@ -46,7 +46,6 @@ class LinearClsHead : public MMClassification { } private: - Value GetLabels(const Tensor& scores, int class_num) const { auto scores_data = scores.data(); auto topk = std::min(topk_, class_num);