Skip to content

Commit bcb93ea

Browse files
authored
[Enhancement] Add optional softmax in LinearClsHead (open-mmlab#1858)
* add softmax in cls postprocess * minor
1 parent f69c636 commit bcb93ea

File tree

1 file changed

+19
-3
lines changed

1 file changed

+19
-3
lines changed

csrc/mmdeploy/codebase/mmcls/linear_cls.cpp

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ class LinearClsHead : public MMClassification {
1818
public:
1919
explicit LinearClsHead(const Value& cfg) : MMClassification(cfg) {
2020
if (cfg.contains("params")) {
21+
softmax_ = cfg["params"].value("softmax", false);
2122
topk_ = cfg["params"].value("topk", 1);
2223
if (topk_ <= 0) {
2324
MMDEPLOY_ERROR("'topk' should be greater than 0, but got '{}'", topk_);
@@ -54,17 +55,32 @@ class LinearClsHead : public MMClassification {
5455
iota(begin(idx), end(idx), 0);
5556
partial_sort(begin(idx), begin(idx) + topk, end(idx),
5657
[&](int i, int j) { return scores_data[i] > scores_data[j]; });
58+
59+
auto sum_exp = 0.f;
60+
std::vector<float> exp_scores;
61+
if (softmax_) {
62+
exp_scores.reserve(class_num);
63+
auto max_val = scores_data[idx[0]];
64+
for (int i = 0; i < class_num; ++i) {
65+
sum_exp += exp_scores.emplace_back(std::exp(scores_data[i] - max_val));
66+
}
67+
}
5768
for (int i = 0; i < topk; ++i) {
58-
auto label = Label{idx[i], scores_data[idx[i]]};
59-
MMDEPLOY_DEBUG("label_id: {}, score: {}", label.label_id, label.score);
60-
output.push_back(label);
69+
float score = 0.f;
70+
if (softmax_) {
71+
score = exp_scores[idx[i]] / sum_exp;
72+
} else {
73+
score = scores_data[idx[i]];
74+
}
75+
output.push_back({idx[i], score});
6176
}
6277
return to_value(std::move(output));
6378
}
6479

6580
private:
6681
static constexpr const auto kHost = Device{0};
6782

83+
bool softmax_{false};
6884
int topk_{1};
6985
};
7086

0 commit comments

Comments
 (0)