diff --git a/csrc/mmdeploy/codebase/mmcls/linear_cls.cpp b/csrc/mmdeploy/codebase/mmcls/linear_cls.cpp index 2f42edd09b..3e1f171e3f 100644 --- a/csrc/mmdeploy/codebase/mmcls/linear_cls.cpp +++ b/csrc/mmdeploy/codebase/mmcls/linear_cls.cpp @@ -18,6 +18,7 @@ class LinearClsHead : public MMClassification { public: explicit LinearClsHead(const Value& cfg) : MMClassification(cfg) { if (cfg.contains("params")) { + softmax_ = cfg["params"].value("softmax", false); topk_ = cfg["params"].value("topk", 1); if (topk_ <= 0) { MMDEPLOY_ERROR("'topk' should be greater than 0, but got '{}'", topk_); @@ -54,10 +55,24 @@ class LinearClsHead : public MMClassification { iota(begin(idx), end(idx), 0); partial_sort(begin(idx), begin(idx) + topk, end(idx), [&](int i, int j) { return scores_data[i] > scores_data[j]; }); + + auto sum_exp = 0.f; + std::vector exp_scores; + if (softmax_) { + exp_scores.reserve(class_num); + auto max_val = scores_data[idx[0]]; + for (int i = 0; i < class_num; ++i) { + sum_exp += exp_scores.emplace_back(std::exp(scores_data[i] - max_val)); + } + } for (int i = 0; i < topk; ++i) { - auto label = Label{idx[i], scores_data[idx[i]]}; - MMDEPLOY_DEBUG("label_id: {}, score: {}", label.label_id, label.score); - output.push_back(label); + float score = 0.f; + if (softmax_) { + score = exp_scores[idx[i]] / sum_exp; + } else { + score = scores_data[idx[i]]; + } + output.push_back({idx[i], score}); } return to_value(std::move(output)); } @@ -65,6 +80,7 @@ class LinearClsHead : public MMClassification { private: static constexpr const auto kHost = Device{0}; + bool softmax_{false}; int topk_{1}; };