@@ -18,6 +18,7 @@ class LinearClsHead : public MMClassification {
18
18
public:
19
19
explicit LinearClsHead (const Value& cfg) : MMClassification(cfg) {
20
20
if (cfg.contains (" params" )) {
21
+ softmax_ = cfg[" params" ].value (" softmax" , false );
21
22
topk_ = cfg[" params" ].value (" topk" , 1 );
22
23
if (topk_ <= 0 ) {
23
24
MMDEPLOY_ERROR (" 'topk' should be greater than 0, but got '{}'" , topk_);
@@ -54,17 +55,32 @@ class LinearClsHead : public MMClassification {
54
55
iota (begin (idx), end (idx), 0 );
55
56
partial_sort (begin (idx), begin (idx) + topk, end (idx),
56
57
[&](int i, int j) { return scores_data[i] > scores_data[j]; });
58
+
59
+ auto sum_exp = 0 .f ;
60
+ std::vector<float > exp_scores;
61
+ if (softmax_) {
62
+ exp_scores.reserve (class_num);
63
+ auto max_val = scores_data[idx[0 ]];
64
+ for (int i = 0 ; i < class_num; ++i) {
65
+ sum_exp += exp_scores.emplace_back (std::exp (scores_data[i] - max_val));
66
+ }
67
+ }
57
68
for (int i = 0 ; i < topk; ++i) {
58
- auto label = Label{idx[i], scores_data[idx[i]]};
59
- MMDEPLOY_DEBUG (" label_id: {}, score: {}" , label.label_id , label.score );
60
- output.push_back (label);
69
+ float score = 0 .f ;
70
+ if (softmax_) {
71
+ score = exp_scores[idx[i]] / sum_exp;
72
+ } else {
73
+ score = scores_data[idx[i]];
74
+ }
75
+ output.push_back ({idx[i], score});
61
76
}
62
77
return to_value (std::move (output));
63
78
}
64
79
65
80
private:
66
81
static constexpr const auto kHost = Device{0 };
67
82
83
+ bool softmax_{false };
68
84
int topk_{1 };
69
85
};
70
86
0 commit comments