@@ -18,6 +18,7 @@ class LinearClsHead : public MMClassification {
  public:
   explicit LinearClsHead(const Value& cfg) : MMClassification(cfg) {
     if (cfg.contains("params")) {
+      softmax_ = cfg["params"].value("softmax", false);
       topk_ = cfg["params"].value("topk", 1);
       if (topk_ <= 0) {
         MMDEPLOY_ERROR("'topk' should be greater than 0, but got '{}'", topk_);
@@ -47,23 +48,39 @@ class LinearClsHead : public MMClassification {
  private:
   Value GetLabels(const Tensor& scores, int class_num) const {
     auto scores_data = scores.data<float>();
+    auto topk = std::min(topk_, class_num);
     Labels output;
-    output.reserve(topk_);
+    output.reserve(topk);
     std::vector<int> idx(class_num);
     iota(begin(idx), end(idx), 0);
-    partial_sort(begin(idx), begin(idx) + topk_, end(idx),
+    partial_sort(begin(idx), begin(idx) + topk, end(idx),
                  [&](int i, int j) { return scores_data[i] > scores_data[j]; });
-    for (int i = 0; i < topk_; ++i) {
-      auto label = Label{idx[i], scores_data[idx[i]]};
-      MMDEPLOY_DEBUG("label_id: {}, score: {}", label.label_id, label.score);
-      output.push_back(label);
+
+    auto sum_exp = 0.f;
+    std::vector<float> exp_scores;
+    if (softmax_) {
+      exp_scores.reserve(class_num);
+      auto max_val = scores_data[idx[0]];
+      for (int i = 0; i < class_num; ++i) {
+        sum_exp += exp_scores.emplace_back(std::exp(scores_data[i] - max_val));
+      }
+    }
+    for (int i = 0; i < topk; ++i) {
+      float score = 0.f;
+      if (softmax_) {
+        score = exp_scores[idx[i]] / sum_exp;
+      } else {
+        score = scores_data[idx[i]];
+      }
+      output.push_back({idx[i], score});
     }
     return to_value(std::move(output));
   }

  private:
   static constexpr const auto kHost = Device{0};

+  bool softmax_{false};
   int topk_{1};
 };

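For anyone who wants to exercise the new branch in isolation, below is a minimal self-contained sketch of the same top-k + optional-softmax logic, with a plain `std::vector<float>` standing in for mmdeploy's `Tensor` and the names `Entry` and `TopKScores` invented purely for illustration (C++17 is assumed, since `emplace_back` only returns a reference from C++17 on):

```cpp
// Standalone sketch of the patch's scoring logic; Entry/TopKScores are
// hypothetical names, not part of mmdeploy.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <numeric>
#include <vector>

struct Entry {
  int label_id;
  float score;
};

// Returns the top-k (label, score) pairs; when `softmax` is true the raw
// logits are normalized with the max-subtraction trick so that std::exp
// never overflows on large logits.
std::vector<Entry> TopKScores(const std::vector<float>& logits, int topk, bool softmax) {
  const int class_num = static_cast<int>(logits.size());
  topk = std::min(topk, class_num);  // clamp, mirroring the patch

  std::vector<int> idx(class_num);
  std::iota(idx.begin(), idx.end(), 0);
  std::partial_sort(idx.begin(), idx.begin() + topk, idx.end(),
                    [&](int i, int j) { return logits[i] > logits[j]; });

  float sum_exp = 0.f;
  std::vector<float> exp_scores;
  if (softmax) {
    exp_scores.reserve(class_num);
    const float max_val = logits[idx[0]];  // idx[0] is the argmax after the sort
    for (int i = 0; i < class_num; ++i) {
      sum_exp += exp_scores.emplace_back(std::exp(logits[i] - max_val));
    }
  }

  std::vector<Entry> output;
  output.reserve(topk);
  for (int i = 0; i < topk; ++i) {
    const float score = softmax ? exp_scores[idx[i]] / sum_exp : logits[idx[i]];
    output.push_back({idx[i], score});
  }
  return output;
}

int main() {
  const std::vector<float> logits{2.f, 1000.f, -3.f, 998.f};
  for (const auto& e : TopKScores(logits, 2, /*softmax=*/true)) {
    std::printf("label_id: %d, score: %f\n", e.label_id, e.score);
  }
}
```

Subtracting `max_val` (the largest logit, which is exactly `scores_data[idx[0]]` once the partial sort has run) before exponentiating is the standard stability trick: softmax is invariant to a constant shift, and the subtraction keeps `std::exp` finite even for logits like the `1000.f` values above, which would overflow `float` if exponentiated directly.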