Skip to content

Commit 30a74da

Browse files
committed
[Enhancement] Add optional softmax in LinearClsHead (open-mmlab#1858)
* add softmax in cls postprocess * minor (cherry picked from commit bcb93ea)
1 parent bace0a2 commit 30a74da

File tree

1 file changed

+23
-6
lines changed

1 file changed

+23
-6
lines changed

csrc/mmdeploy/codebase/mmcls/linear_cls.cpp

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ class LinearClsHead : public MMClassification {
1818
public:
1919
explicit LinearClsHead(const Value& cfg) : MMClassification(cfg) {
2020
if (cfg.contains("params")) {
21+
softmax_ = cfg["params"].value("softmax", false);
2122
topk_ = cfg["params"].value("topk", 1);
2223
if (topk_ <= 0) {
2324
MMDEPLOY_ERROR("'topk' should be greater than 0, but got '{}'", topk_);
@@ -47,23 +48,39 @@ class LinearClsHead : public MMClassification {
4748
private:
4849
Value GetLabels(const Tensor& scores, int class_num) const {
4950
auto scores_data = scores.data<float>();
51+
auto topk = std::min(topk_, class_num);
5052
Labels output;
51-
output.reserve(topk_);
53+
output.reserve(topk);
5254
std::vector<int> idx(class_num);
5355
iota(begin(idx), end(idx), 0);
54-
partial_sort(begin(idx), begin(idx) + topk_, end(idx),
56+
partial_sort(begin(idx), begin(idx) + topk, end(idx),
5557
[&](int i, int j) { return scores_data[i] > scores_data[j]; });
56-
for (int i = 0; i < topk_; ++i) {
57-
auto label = Label{idx[i], scores_data[idx[i]]};
58-
MMDEPLOY_DEBUG("label_id: {}, score: {}", label.label_id, label.score);
59-
output.push_back(label);
58+
59+
auto sum_exp = 0.f;
60+
std::vector<float> exp_scores;
61+
if (softmax_) {
62+
exp_scores.reserve(class_num);
63+
auto max_val = scores_data[idx[0]];
64+
for (int i = 0; i < class_num; ++i) {
65+
sum_exp += exp_scores.emplace_back(std::exp(scores_data[i] - max_val));
66+
}
67+
}
68+
for (int i = 0; i < topk; ++i) {
69+
float score = 0.f;
70+
if (softmax_) {
71+
score = exp_scores[idx[i]] / sum_exp;
72+
} else {
73+
score = scores_data[idx[i]];
74+
}
75+
output.push_back({idx[i], score});
6076
}
6177
return to_value(std::move(output));
6278
}
6379

6480
private:
6581
static constexpr const auto kHost = Device{0};
6682

83+
bool softmax_{false};
6784
int topk_{1};
6885
};
6986

0 commit comments

Comments
 (0)