55#include " sherpa-onnx/csrc/axcl/offline-fire-red-asr-ctc-model-axcl.h"
66
77#include < algorithm>
8+ #include < cmath>
9+ #include < cstring>
810#include < memory>
911#include < string>
1012#include < utility>
@@ -64,28 +66,47 @@ class OfflineFireRedAsrCtcModelAxcl::Impl {
6466 int32_t expected_frames = expected_shape[1 ];
6567
6668 int32_t valid_frames = std::min<int32_t >(num_frames, expected_frames);
69+ valid_frames = std::min<int32_t >(valid_frames,
70+ static_cast <int32_t >(p_features_length[0 ]));
6771 std::vector<float > padded_features (expected_frames * feat_dim, 0 .0f );
6872 std::copy (p_features, p_features + valid_frames * feat_dim,
6973 padded_features.begin ());
7074
7175 std::vector<int32_t > speech_length = {valid_frames};
7276
73- model_->SetInputTensorData (model_->InputTensorNames ()[0 ],
74- padded_features.data (), padded_features.size ());
75- model_->SetInputTensorData (model_->InputTensorNames ()[1 ],
76- speech_length.data (), speech_length.size ());
77+ if (!model_->SetInputTensorData (model_->InputTensorNames ()[0 ],
78+ padded_features.data (),
79+ padded_features.size ()) ||
80+ !model_->SetInputTensorData (model_->InputTensorNames ()[1 ],
81+ speech_length.data (),
82+ speech_length.size ())) {
83+ SHERPA_ONNX_LOGE (" Failed to set input tensors for axcl FireRedASR CTC" );
84+ SHERPA_ONNX_EXIT (-1 );
85+ }
7786
78- model_->Run ();
87+ if (!model_->Run ()) {
88+ SHERPA_ONNX_LOGE (" Failed to run axcl FireRedASR CTC model" );
89+ SHERPA_ONNX_EXIT (-1 );
90+ }
7991
8092 auto out_logits =
8193 model_->GetOutputTensorData (model_->OutputTensorNames ()[0 ]);
82- auto out_lengths =
83- model_->GetOutputTensorData (model_->OutputTensorNames ()[1 ]);
94+ auto out_lengths_raw =
95+ model_->GetOutputTensorDataRaw (model_->OutputTensorNames ()[1 ]);
8496
8597 auto out_shape = model_->TensorShape (model_->OutputTensorNames ()[0 ]);
8698 int32_t out_frames = out_shape[1 ];
8799 int32_t vocab_size = out_shape[2 ];
88100
101+ if (static_cast <int32_t >(out_logits.size ()) != out_frames * vocab_size) {
102+ SHERPA_ONNX_LOGE (
103+ " Unexpected logits size from axcl FireRedASR CTC. Got %d values. "
104+ " Expected %d x %d = %d" ,
105+ static_cast <int32_t >(out_logits.size ()), out_frames, vocab_size,
106+ out_frames * vocab_size);
107+ SHERPA_ONNX_EXIT (-1 );
108+ }
109+
89110 std::array<int64_t , 3 > logits_shape = {1 , out_frames, vocab_size};
90111 Ort::Value logits = Ort::Value::CreateTensor<float >(
91112 allocator_, logits_shape.data (), logits_shape.size ());
@@ -96,7 +117,7 @@ class OfflineFireRedAsrCtcModelAxcl::Impl {
96117 Ort::Value lengths = Ort::Value::CreateTensor<int64_t >(
97118 allocator_, lengths_shape.data (), lengths_shape.size ());
98119 int64_t *p_lengths = lengths.GetTensorMutableData <int64_t >();
99- p_lengths[0 ] = static_cast < int64_t >(out_lengths[ 0 ] );
120+ p_lengths[0 ] = ReadOutputLength (out_lengths_raw, out_frames );
100121
101122 std::vector<Ort::Value> ans;
102123 ans.push_back (std::move (logits));
@@ -131,6 +152,55 @@ class OfflineFireRedAsrCtcModelAxcl::Impl {
131152 }
132153
133154 private:
155+ static int64_t ReadOutputLength (const std::vector<uint8_t > &raw,
156+ int32_t max_frames) {
157+ if (raw.empty ()) {
158+ SHERPA_ONNX_LOGE (
159+ " The output-length tensor from axcl FireRedASR CTC is empty" );
160+ SHERPA_ONNX_EXIT (-1 );
161+ }
162+
163+ auto in_range = [max_frames](int64_t value) {
164+ return 0 <= value && value <= max_frames;
165+ };
166+
167+ if (raw.size () == sizeof (int64_t )) {
168+ int64_t value = 0 ;
169+ std::memcpy (&value, raw.data (), sizeof (value));
170+ if (in_range (value)) {
171+ return value;
172+ }
173+ }
174+
175+ if (raw.size () >= sizeof (int32_t )) {
176+ int32_t value_i32 = 0 ;
177+ std::memcpy (&value_i32, raw.data (), sizeof (value_i32));
178+ if (in_range (value_i32)) {
179+ return value_i32;
180+ }
181+
182+ uint32_t value_u32 = 0 ;
183+ std::memcpy (&value_u32, raw.data (), sizeof (value_u32));
184+ if (value_u32 <= static_cast <uint32_t >(max_frames)) {
185+ return value_u32;
186+ }
187+
188+ float value_f32 = 0 ;
189+ std::memcpy (&value_f32, raw.data (), sizeof (value_f32));
190+ if (std::isfinite (value_f32) && value_f32 >= 0 &&
191+ value_f32 <= static_cast <float >(max_frames)) {
192+ return static_cast <int64_t >(value_f32);
193+ }
194+ }
195+
196+ SHERPA_ONNX_LOGE (
197+ " Failed to interpret the output-length tensor from axcl FireRedASR "
198+ " CTC. Byte size: %d. Max frames: %d" ,
199+ static_cast <int32_t >(raw.size ()), max_frames);
200+ SHERPA_ONNX_EXIT (-1 );
201+ return 0 ;
202+ }
203+
134204 void Init () {
135205 if (!model_->IsInitialized ()) {
136206 SHERPA_ONNX_LOGE (" Failed to initialize the model with '%s'" ,
0 commit comments