Skip to content

Commit 9c9d09e

Browse files
author
LittleMouse
committed
FireRedASR2 supported axera backend
1 parent 25e01ae commit 9c9d09e

5 files changed

Lines changed: 415 additions & 6 deletions

File tree

sherpa-onnx/csrc/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,7 @@ endif()
208208
if(SHERPA_ONNX_ENABLE_AXERA)
209209
list(APPEND sources
210210
./axera/ax-engine-guard.cc
211+
./axera/offline-fire-red-asr-ctc-model-axera.cc
211212
./axera/offline-sense-voice-model-axera.cc
212213
./axera/utils.cc
213214
)
Lines changed: 344 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,344 @@
1+
// sherpa-onnx/csrc/axera/offline-fire-red-asr-ctc-model-axera.cc
2+
//
3+
// Copyright (c) 2026 Xiaomi Corporation
4+
5+
#include "sherpa-onnx/csrc/axera/offline-fire-red-asr-ctc-model-axera.h"
6+
7+
#include <algorithm>
8+
#include <array>
9+
#include <cstdint>
10+
#include <cstring>
11+
#include <mutex>
12+
#include <utility>
13+
#include <vector>
14+
15+
#if __ANDROID_API__ >= 9
16+
#include "android/asset_manager.h"
17+
#include "android/asset_manager_jni.h"
18+
#endif
19+
20+
#if __OHOS__
21+
#include "rawfile/raw_file_manager.h"
22+
#endif
23+
24+
#include "Eigen/Dense"
25+
#include "ax_engine_api.h" // NOLINT
26+
#include "sherpa-onnx/csrc/axera/ax-engine-guard.h"
27+
#include "sherpa-onnx/csrc/axera/utils.h"
28+
#include "sherpa-onnx/csrc/file-utils.h"
29+
#include "sherpa-onnx/csrc/macros.h"
30+
31+
namespace sherpa_onnx {
32+
33+
class OfflineFireRedAsrCtcModelAxera::Impl {
34+
public:
35+
~Impl() {
36+
FreeIO(&io_data_);
37+
if (handle_) {
38+
AX_ENGINE_DestroyHandle(handle_);
39+
}
40+
}
41+
42+
explicit Impl(const OfflineModelConfig &config)
43+
: config_(config), allocator_{} {
44+
auto buf = ReadFile(config_.fire_red_asr_ctc.model);
45+
Init(buf.data(), buf.size());
46+
}
47+
48+
template <typename Manager>
49+
Impl(Manager *mgr, const OfflineModelConfig &config)
50+
: config_(config), allocator_{} {
51+
auto buf = ReadFile(mgr, config_.fire_red_asr_ctc.model);
52+
Init(buf.data(), buf.size());
53+
}
54+
55+
std::vector<Ort::Value> Forward(Ort::Value features,
56+
Ort::Value features_length) {
57+
std::lock_guard<std::mutex> lock(mutex_);
58+
59+
auto features_shape = features.GetTensorTypeAndShapeInfo().GetShape();
60+
int32_t batch_size = features_shape[0];
61+
int32_t num_frames = features_shape[1];
62+
int32_t feat_dim = features_shape[2];
63+
64+
const float *p_features = features.GetTensorData<float>();
65+
const int64_t *p_features_length = features_length.GetTensorData<int64_t>();
66+
67+
if (batch_size != 1) {
68+
SHERPA_ONNX_LOGE("Only batch size 1 is supported by axera. Given: %d",
69+
batch_size);
70+
SHERPA_ONNX_EXIT(-1);
71+
}
72+
73+
int32_t expected_frames = io_info_->pInputs[0].pShape[1];
74+
75+
int32_t valid_frames = std::min<int32_t>(num_frames, expected_frames);
76+
valid_frames = std::min<int32_t>(valid_frames,
77+
static_cast<int32_t>(p_features_length[0]));
78+
79+
std::vector<float> padded_features(expected_frames * feat_dim, 0.0f);
80+
std::copy(p_features, p_features + valid_frames * feat_dim,
81+
padded_features.begin());
82+
83+
std::vector<int32_t> speech_length = {valid_frames};
84+
85+
const auto &in0_meta = io_info_->pInputs[0];
86+
size_t bytes0 = in0_meta.nSize;
87+
if (bytes0 != padded_features.size() * sizeof(float)) {
88+
SHERPA_ONNX_LOGE(
89+
"Feature size mismatch. model expects %u bytes, but got %zu bytes",
90+
in0_meta.nSize, padded_features.size() * sizeof(float));
91+
SHERPA_ONNX_EXIT(-1);
92+
}
93+
94+
std::memcpy(io_data_.pInputs[0].pVirAddr, padded_features.data(), bytes0);
95+
96+
const auto &in1_meta = io_info_->pInputs[1];
97+
size_t bytes1 = in1_meta.nSize;
98+
if (bytes1 != speech_length.size() * sizeof(int32_t)) {
99+
SHERPA_ONNX_LOGE(
100+
"Speech length size mismatch. model expects %u bytes, but got %zu "
101+
"bytes",
102+
in1_meta.nSize, speech_length.size() * sizeof(int32_t));
103+
SHERPA_ONNX_EXIT(-1);
104+
}
105+
106+
std::memcpy(io_data_.pInputs[1].pVirAddr, speech_length.data(), bytes1);
107+
108+
auto ret = AX_ENGINE_RunSync(handle_, &io_data_);
109+
if (ret != 0) {
110+
SHERPA_ONNX_LOGE("AX_ENGINE_RunSync failed, ret = %d", ret);
111+
SHERPA_ONNX_EXIT(-1);
112+
}
113+
114+
const auto &out0_meta = io_info_->pOutputs[0];
115+
const auto &out0_buf = io_data_.pOutputs[0];
116+
117+
int32_t out_frames = out0_meta.pShape[1];
118+
int32_t vocab_size = out0_meta.pShape[2];
119+
120+
std::array<int64_t, 3> logits_shape = {1, out_frames, vocab_size};
121+
Ort::Value logits = Ort::Value::CreateTensor<float>(
122+
allocator_, logits_shape.data(), logits_shape.size());
123+
124+
float *p_logits = logits.GetTensorMutableData<float>();
125+
std::memcpy(p_logits, out0_buf.pVirAddr, out0_meta.nSize);
126+
127+
const auto &out1_meta = io_info_->pOutputs[1];
128+
const auto &out1_buf = io_data_.pOutputs[1];
129+
130+
int64_t out_length = 0;
131+
if (out1_meta.eDataType == AX_ENGINE_DT_SINT32) {
132+
out_length = static_cast<int64_t>(
133+
reinterpret_cast<const int32_t *>(out1_buf.pVirAddr)[0]);
134+
} else if (out1_meta.eDataType == AX_ENGINE_DT_UINT32) {
135+
out_length = static_cast<int64_t>(
136+
reinterpret_cast<const uint32_t *>(out1_buf.pVirAddr)[0]);
137+
} else if (out1_meta.eDataType == AX_ENGINE_DT_FLOAT32) {
138+
out_length = static_cast<int64_t>(
139+
reinterpret_cast<const float *>(out1_buf.pVirAddr)[0]);
140+
} else {
141+
SHERPA_ONNX_LOGE("Unsupported length output dtype: %d",
142+
static_cast<int32_t>(out1_meta.eDataType));
143+
SHERPA_ONNX_EXIT(-1);
144+
}
145+
146+
std::array<int64_t, 1> lengths_shape = {1};
147+
Ort::Value lengths = Ort::Value::CreateTensor<int64_t>(
148+
allocator_, lengths_shape.data(), lengths_shape.size());
149+
150+
int64_t *p_lengths = lengths.GetTensorMutableData<int64_t>();
151+
p_lengths[0] = out_length;
152+
153+
std::vector<Ort::Value> ans;
154+
ans.push_back(std::move(logits));
155+
ans.push_back(std::move(lengths));
156+
157+
return ans;
158+
}
159+
160+
int32_t VocabSize() const { return vocab_size_; }
161+
162+
int32_t SubsamplingFactor() const { return subsampling_factor_; }
163+
164+
OrtAllocator *Allocator() { return allocator_; }
165+
166+
void NormalizeFeatures(float *features, int32_t num_frames,
167+
int32_t feat_dim) const {
168+
if (static_cast<int32_t>(mean_.size()) != feat_dim) {
169+
SHERPA_ONNX_LOGE("Bad things happened");
170+
SHERPA_ONNX_LOGE("Wrong feat dim %d. Expect: %d", feat_dim,
171+
static_cast<int32_t>(mean_.size()));
172+
SHERPA_ONNX_EXIT(-1);
173+
}
174+
175+
using RowMajorMat =
176+
Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
177+
Eigen::Map<RowMajorMat> x(features, num_frames, feat_dim);
178+
179+
Eigen::Map<const Eigen::RowVectorXf> mean(mean_.data(), feat_dim);
180+
Eigen::Map<const Eigen::RowVectorXf> inv_std(inv_stddev_.data(), feat_dim);
181+
x.array() =
182+
(x.array().rowwise() - mean.array()).rowwise() * inv_std.array();
183+
}
184+
185+
private:
186+
void Init(void *model_data, size_t model_data_length) {
187+
InitContext(model_data, model_data_length, config_.debug, &handle_);
188+
189+
InitInputOutputAttrs(handle_, config_.debug, &io_info_);
190+
191+
PrepareIO(io_info_, &io_data_, config_.debug);
192+
193+
if (!io_info_ || io_info_->nInputSize != 2 || !io_info_->pInputs) {
194+
SHERPA_ONNX_LOGE("Axera FireRedASR CTC model expects 2 input tensors.");
195+
SHERPA_ONNX_EXIT(-1);
196+
}
197+
198+
if (!io_info_->pOutputs || io_info_->nOutputSize != 2) {
199+
SHERPA_ONNX_LOGE(
200+
"Axera FireRedASR CTC model expects 2 output tensors.");
201+
SHERPA_ONNX_EXIT(-1);
202+
}
203+
204+
if (io_info_->pOutputs[0].nShapeSize < 3) {
205+
SHERPA_ONNX_LOGE(
206+
"The first output tensor rank is too small (nShapeSize = %u)",
207+
io_info_->pOutputs[0].nShapeSize);
208+
SHERPA_ONNX_EXIT(-1);
209+
}
210+
211+
subsampling_factor_ = 4;
212+
vocab_size_ = io_info_->pOutputs[0].pShape[io_info_->pOutputs[0].nShapeSize -
213+
1];
214+
215+
if (config_.debug) {
216+
#if __OHOS__
217+
SHERPA_ONNX_LOGE("subsampling_factor: %{public}d", subsampling_factor_);
218+
SHERPA_ONNX_LOGE("vocab_size: %{public}d", vocab_size_);
219+
#else
220+
SHERPA_ONNX_LOGE("subsampling_factor: %d", subsampling_factor_);
221+
SHERPA_ONNX_LOGE("vocab_size: %d", vocab_size_);
222+
#endif
223+
}
224+
225+
mean_ = {10.498912811279297, 10.948603630065918, 11.889163970947266,
226+
12.634881973266602, 13.397452354431152, 14.010934829711914,
227+
14.450813293457031, 14.649748802185059, 14.791581153869629,
228+
14.72234058380127, 14.802156448364258, 14.86101245880127,
229+
15.077230453491211, 15.26024341583252, 15.328754425048828,
230+
15.397353172302246, 15.395853996276855, 15.34103775024414,
231+
15.4662446975708, 15.271865844726562, 15.108253479003906,
232+
15.295886993408203, 15.07359504699707, 15.177886009216309,
233+
15.0756254196167, 15.154109001159668, 15.051127433776855,
234+
15.130733489990234, 15.090286254882812, 15.099433898925781,
235+
15.128166198730469, 15.123964309692383, 15.144022941589355,
236+
15.198014259338379, 15.251392364501953, 15.329950332641602,
237+
15.4017972946167, 15.45089340209961, 15.500616073608398,
238+
15.435726165771484, 15.51086139678955, 15.44755744934082,
239+
15.510979652404785, 15.491739273071289, 15.538031578063965,
240+
15.608367919921875, 15.694382667541504, 15.762181282043457,
241+
15.821470260620117, 15.901959419250488, 15.907241821289062,
242+
15.925711631774902, 15.952259063720703, 16.000732421875,
243+
16.030330657958984, 16.060592651367188, 16.09003448486328,
244+
16.100107192993164, 16.091808319091797, 16.062585830688477,
245+
16.05771255493164, 15.997002601623535, 15.946383476257324,
246+
15.865278244018555, 15.778145790100098, 15.67629623413086,
247+
15.569791793823242, 15.515979766845703, 15.472077369689941,
248+
15.423379898071289, 15.382068634033203, 15.345854759216309,
249+
15.301891326904297, 15.26984691619873, 15.165450096130371,
250+
15.004508972167969, 14.87544059753418, 14.564188003540039,
251+
14.031693458557129, 13.159259796142578};
252+
inv_stddev_ = {
253+
0.2522108852863312, 0.23741021752357483, 0.23185651004314423,
254+
0.23331022262573242, 0.23203925788402557, 0.22906658053398132,
255+
0.22519451379776, 0.22010253369808197, 0.21958276629447937,
256+
0.22198699414730072, 0.22393390536308289, 0.22370608150959015,
257+
0.22321352362632751, 0.2220749408006668, 0.22118520736694336,
258+
0.22136786580085754, 0.2220366895198822, 0.222808837890625,
259+
0.22362081706523895, 0.224283829331398, 0.22464141249656677,
260+
0.22580783069133759, 0.22700978815555573, 0.22852766513824463,
261+
0.22993983328342438, 0.23110738396644592, 0.23227347433567047,
262+
0.23270530998706818, 0.23330524563789368, 0.23406001925468445,
263+
0.23448589444160461, 0.23556077480316162, 0.23632891476154327,
264+
0.23703691363334656, 0.2377307415008545, 0.23786373436450958,
265+
0.2380155622959137, 0.23858875036239624, 0.23943373560905457,
266+
0.2399062216281891, 0.24094033241271973, 0.24173252284526825,
267+
0.24236661195755005, 0.2430112659931183, 0.24341483414173126,
268+
0.243240088224411, 0.24262498319149017, 0.24218837916851044,
269+
0.24165891110897064, 0.241318941116333, 0.2413933277130127,
270+
0.24139994382858276, 0.241432324051857, 0.24122384190559387,
271+
0.24079066514968872, 0.24032147228717804, 0.24016834795475006,
272+
0.24034327268600464, 0.24069449305534363, 0.24123424291610718,
273+
0.24136029183864594, 0.24150611460208893, 0.24179506301879883,
274+
0.24160170555114746, 0.24221885204315186, 0.24253536760807037,
275+
0.24262426793575287, 0.2428186535835266, 0.24223484098911285,
276+
0.24199971556663513, 0.24160003662109375, 0.24074721336364746,
277+
0.23965489864349365, 0.23850350081920624, 0.2359732687473297,
278+
0.23006057739257812, 0.22904986143112183, 0.22814501821994781,
279+
0.22893856465816498, 0.23093441128730774};
280+
}
281+
282+
private:
283+
std::mutex mutex_;
284+
AxEngineGuard ax_engine_guard_;
285+
286+
OfflineModelConfig config_;
287+
AX_ENGINE_HANDLE handle_ = nullptr;
288+
AX_ENGINE_IO_INFO_T *io_info_ = nullptr;
289+
AX_ENGINE_IO_T io_data_;
290+
Ort::AllocatorWithDefaultOptions allocator_;
291+
292+
int32_t vocab_size_ = 0;
293+
int32_t subsampling_factor_ = 0;
294+
295+
std::vector<float> mean_;
296+
std::vector<float> inv_stddev_;
297+
};
298+
299+
OfflineFireRedAsrCtcModelAxera::OfflineFireRedAsrCtcModelAxera(
300+
const OfflineModelConfig &config)
301+
: impl_(std::make_unique<Impl>(config)) {}
302+
303+
template <typename Manager>
304+
OfflineFireRedAsrCtcModelAxera::OfflineFireRedAsrCtcModelAxera(
305+
Manager *mgr, const OfflineModelConfig &config)
306+
: impl_(std::make_unique<Impl>(mgr, config)) {}
307+
308+
OfflineFireRedAsrCtcModelAxera::~OfflineFireRedAsrCtcModelAxera() = default;
309+
310+
std::vector<Ort::Value> OfflineFireRedAsrCtcModelAxera::Forward(
311+
Ort::Value features, Ort::Value features_length) {
312+
return impl_->Forward(std::move(features), std::move(features_length));
313+
}
314+
315+
int32_t OfflineFireRedAsrCtcModelAxera::VocabSize() const {
316+
return impl_->VocabSize();
317+
}
318+
319+
int32_t OfflineFireRedAsrCtcModelAxera::SubsamplingFactor() const {
320+
return impl_->SubsamplingFactor();
321+
}
322+
323+
OrtAllocator *OfflineFireRedAsrCtcModelAxera::Allocator() const {
324+
return impl_->Allocator();
325+
}
326+
327+
void OfflineFireRedAsrCtcModelAxera::NormalizeFeatures(float *features,
328+
int32_t num_frames,
329+
int32_t feat_dim) const {
330+
return impl_->NormalizeFeatures(features, num_frames, feat_dim);
331+
}
332+
333+
#if __ANDROID_API__ >= 9
334+
template OfflineFireRedAsrCtcModelAxera::OfflineFireRedAsrCtcModelAxera(
335+
AAssetManager *mgr, const OfflineModelConfig &config);
336+
#endif
337+
338+
#if __OHOS__
339+
template OfflineFireRedAsrCtcModelAxera::OfflineFireRedAsrCtcModelAxera(
340+
NativeResourceManager *mgr, const OfflineModelConfig &config);
341+
#endif
342+
343+
} // namespace sherpa_onnx
344+

0 commit comments

Comments
 (0)