Skip to content

Commit 67e348b

Browse files
author
LittleMouse
committed
Add GetOutputTensorDataRaw method to AxclModel and update FireRedASR CTC model to use it
1 parent 25e01ae commit 67e348b

3 files changed

Lines changed: 110 additions & 8 deletions

File tree

sherpa-onnx/csrc/axcl/axcl-model.cc

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,31 @@ class AxclModel::Impl {
208208
return {};
209209
}
210210

211+
std::vector<uint8_t> GetOutputTensorDataRaw(const std::string &name) const {
212+
for (size_t i = 0; i < output_tensor_names_.size(); ++i) {
213+
if (output_tensor_names_[i] == name) {
214+
size_t bytes = output_tensors_[i].Size();
215+
std::vector<uint8_t> out(bytes);
216+
217+
auto ret = axclrtMemcpy(out.data(), output_tensors_[i].Get(), bytes,
218+
AXCL_MEMCPY_DEVICE_TO_HOST);
219+
if (ret != 0) {
220+
SHERPA_ONNX_LOGE(
221+
"Failed to call axclrtMemcpy(). tensor name: '%s', return code: "
222+
"%d",
223+
name.c_str(), static_cast<int32_t>(ret));
224+
return {};
225+
}
226+
227+
return out;
228+
}
229+
}
230+
231+
SHERPA_ONNX_LOGE("Found no tensor with name: '%s'", name.c_str());
232+
233+
return {};
234+
}
235+
211236
bool Run() const {
212237
uint32_t group = 0;
213238
auto ret =
@@ -434,6 +459,11 @@ std::vector<float> AxclModel::GetOutputTensorData(
434459
return impl_->GetOutputTensorData(name);
435460
}
436461

462+
std::vector<uint8_t> AxclModel::GetOutputTensorDataRaw(
463+
const std::string &name) const {
464+
return impl_->GetOutputTensorDataRaw(name);
465+
}
466+
437467
bool AxclModel::Run() const { return impl_->Run(); }
438468

439469
bool AxclModel::IsInitialized() const { return impl_->IsInitialized(); }

sherpa-onnx/csrc/axcl/axcl-model.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ class AxclModel {
3636

3737
std::vector<float> GetOutputTensorData(const std::string &name) const;
3838

39+
// Copy the output tensor named `name` from the device to the host and return
// its contents as raw bytes, without interpreting the element type.
// Returns an empty vector if no output tensor has that name or if the
// device-to-host copy fails.
std::vector<uint8_t> GetOutputTensorDataRaw(const std::string &name) const;
40+
3941
bool Run() const;
4042
bool IsInitialized() const;
4143

sherpa-onnx/csrc/axcl/offline-fire-red-asr-ctc-model-axcl.cc

Lines changed: 78 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
#include "sherpa-onnx/csrc/axcl/offline-fire-red-asr-ctc-model-axcl.h"
66

77
#include <algorithm>
8+
#include <cmath>
9+
#include <cstring>
810
#include <memory>
911
#include <string>
1012
#include <utility>
@@ -64,28 +66,47 @@ class OfflineFireRedAsrCtcModelAxcl::Impl {
6466
int32_t expected_frames = expected_shape[1];
6567

6668
int32_t valid_frames = std::min<int32_t>(num_frames, expected_frames);
69+
valid_frames = std::min<int32_t>(valid_frames,
70+
static_cast<int32_t>(p_features_length[0]));
6771
std::vector<float> padded_features(expected_frames * feat_dim, 0.0f);
6872
std::copy(p_features, p_features + valid_frames * feat_dim,
6973
padded_features.begin());
7074

7175
std::vector<int32_t> speech_length = {valid_frames};
7276

73-
model_->SetInputTensorData(model_->InputTensorNames()[0],
74-
padded_features.data(), padded_features.size());
75-
model_->SetInputTensorData(model_->InputTensorNames()[1],
76-
speech_length.data(), speech_length.size());
77+
if (!model_->SetInputTensorData(model_->InputTensorNames()[0],
78+
padded_features.data(),
79+
padded_features.size()) ||
80+
!model_->SetInputTensorData(model_->InputTensorNames()[1],
81+
speech_length.data(),
82+
speech_length.size())) {
83+
SHERPA_ONNX_LOGE("Failed to set input tensors for axcl FireRedASR CTC");
84+
SHERPA_ONNX_EXIT(-1);
85+
}
7786

78-
model_->Run();
87+
if (!model_->Run()) {
88+
SHERPA_ONNX_LOGE("Failed to run axcl FireRedASR CTC model");
89+
SHERPA_ONNX_EXIT(-1);
90+
}
7991

8092
auto out_logits =
8193
model_->GetOutputTensorData(model_->OutputTensorNames()[0]);
82-
auto out_lengths =
83-
model_->GetOutputTensorData(model_->OutputTensorNames()[1]);
94+
auto out_lengths_raw =
95+
model_->GetOutputTensorDataRaw(model_->OutputTensorNames()[1]);
8496

8597
auto out_shape = model_->TensorShape(model_->OutputTensorNames()[0]);
8698
int32_t out_frames = out_shape[1];
8799
int32_t vocab_size = out_shape[2];
88100

101+
if (static_cast<int32_t>(out_logits.size()) != out_frames * vocab_size) {
102+
SHERPA_ONNX_LOGE(
103+
"Unexpected logits size from axcl FireRedASR CTC. Got %d values. "
104+
"Expected %d x %d = %d",
105+
static_cast<int32_t>(out_logits.size()), out_frames, vocab_size,
106+
out_frames * vocab_size);
107+
SHERPA_ONNX_EXIT(-1);
108+
}
109+
89110
std::array<int64_t, 3> logits_shape = {1, out_frames, vocab_size};
90111
Ort::Value logits = Ort::Value::CreateTensor<float>(
91112
allocator_, logits_shape.data(), logits_shape.size());
@@ -96,7 +117,7 @@ class OfflineFireRedAsrCtcModelAxcl::Impl {
96117
Ort::Value lengths = Ort::Value::CreateTensor<int64_t>(
97118
allocator_, lengths_shape.data(), lengths_shape.size());
98119
int64_t *p_lengths = lengths.GetTensorMutableData<int64_t>();
99-
p_lengths[0] = static_cast<int64_t>(out_lengths[0]);
120+
p_lengths[0] = ReadOutputLength(out_lengths_raw, out_frames);
100121

101122
std::vector<Ort::Value> ans;
102123
ans.push_back(std::move(logits));
@@ -131,6 +152,55 @@ class OfflineFireRedAsrCtcModelAxcl::Impl {
131152
}
132153

133154
private:
155+
// Decode the output-length value from the raw bytes of the length tensor.
//
// The element type of the tensor is not known at this point, so candidate
// interpretations of the leading bytes are tried in a fixed order and the
// first one that yields a value in [0, max_frames] wins:
//   1. int64_t  (only when the buffer is exactly 8 bytes)
//   2. int32_t  (first 4 bytes)
//   3. float    (first 4 bytes, truncated toward zero)
//
// @param raw         Bytes copied from the output-length tensor.
// @param max_frames  Upper bound for a plausible length (the number of
//                    output frames of the logits tensor).
// @return The decoded length. If `raw` is empty or no interpretation gives a
//         value in range, an error is logged and SHERPA_ONNX_EXIT(-1) is
//         invoked.
static int64_t ReadOutputLength(const std::vector<uint8_t> &raw,
                                int32_t max_frames) {
  if (raw.empty()) {
    SHERPA_ONNX_LOGE(
        "The output-length tensor from axcl FireRedASR CTC is empty");
    SHERPA_ONNX_EXIT(-1);
  }

  auto in_range = [max_frames](int64_t value) {
    return 0 <= value && value <= max_frames;
  };

  if (raw.size() == sizeof(int64_t)) {
    int64_t value = 0;
    std::memcpy(&value, raw.data(), sizeof(value));
    if (in_range(value)) {
      return value;
    }
    // Otherwise fall through and try the 32-bit interpretations of the
    // leading 4 bytes.
  }

  if (raw.size() >= sizeof(int32_t)) {
    int32_t value_i32 = 0;
    std::memcpy(&value_i32, raw.data(), sizeof(value_i32));
    if (in_range(value_i32)) {
      return value_i32;
    }

    // There is deliberately no separate uint32_t probe. For max_frames >= 0,
    // every uint32_t value <= max_frames has the same bits as an in-range
    // int32_t and was already returned above. For max_frames < 0, casting
    // the bound to uint32_t would wrap to a huge number and accept an
    // arbitrary bit pattern instead of reporting the error below.

    float value_f32 = 0;
    std::memcpy(&value_f32, raw.data(), sizeof(value_f32));
    if (std::isfinite(value_f32) && value_f32 >= 0 &&
        value_f32 <= static_cast<float>(max_frames)) {
      return static_cast<int64_t>(value_f32);
    }
  }

  SHERPA_ONNX_LOGE(
      "Failed to interpret the output-length tensor from axcl FireRedASR "
      "CTC. Byte size: %d. Max frames: %d",
      static_cast<int32_t>(raw.size()), max_frames);
  SHERPA_ONNX_EXIT(-1);
  // Keeps every control path returning a value regardless of how
  // SHERPA_ONNX_EXIT is defined.
  return 0;
}
203+
134204
void Init() {
135205
if (!model_->IsInitialized()) {
136206
SHERPA_ONNX_LOGE("Failed to initialize the model with '%s'",

0 commit comments

Comments
 (0)