FireRedASR2 supported axera backend (#3273)
Conversation
📝 Walkthrough

Adds two new offline FireRedASR CTC model backend implementations for the AXERA and AXCL providers. Includes a header/implementation pair for each backend with model loading, feature normalization, and inference execution. Updates CMakeLists.txt to register the new sources and modifies the factory logic to enable provider-specific model selection.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Client
    participant OfflineRecognizer
    participant Factory as OfflineCtcModel::Create
    participant Selection as Provider Selector
    participant AxeraModel as OfflineFireRedAsrCtcModelAxera
    participant AxclModel as OfflineFireRedAsrCtcModelAxcl
    participant Fallback as OfflineFireRedAsrCtcModel

    Client->>OfflineRecognizer: Create with fire_red_asr_ctc config
    OfflineRecognizer->>Factory: Create(config, model_config)
    Factory->>Selection: Check provider & model type

    alt provider == "axera"
        Selection->>AxeraModel: new OfflineFireRedAsrCtcModelAxera(config)
        AxeraModel->>AxeraModel: Initialize AX engine, load model
        AxeraModel-->>Factory: return instance
    else provider == "axcl"
        Selection->>AxclModel: new OfflineFireRedAsrCtcModelAxcl(config)
        AxclModel->>AxclModel: Initialize AXCL model, precompute stats
        AxclModel-->>Factory: return instance
    else default
        Selection->>Fallback: new OfflineFireRedAsrCtcModel(config)
        Fallback-->>Factory: return instance
    end

    Factory-->>OfflineRecognizer: return model instance
    OfflineRecognizer-->>Client: ready for inference
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs
🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (1 passed)
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly expands the capabilities of the …

Highlights
Code Review
This pull request adds support for FireRedASR CTC models for the axera and axcl backends. The changes include adding new model implementation files for both backends, updating CMakeLists.txt to include them in the build, and modifying the factory functions in offline-ctc-model.cc and offline-recognizer-impl.cc to create instances of these new models.
My review has identified a few areas for improvement:

- There is a potential bug in the `axcl` implementation where the actual feature length is not being used, which could lead to incorrect results.
- The `axcl` implementation may not be thread-safe, unlike its `axera` counterpart.
- There is significant code duplication between the `axera` and `axcl` implementations, particularly for feature normalization logic and constants. I've suggested refactoring this into a common utility to improve maintainability, which aligns with the repository's guidelines.
```cpp
auto expected_shape = model_->TensorShape(model_->InputTensorNames()[0]);
int32_t expected_frames = expected_shape[1];

int32_t valid_frames = std::min<int32_t>(num_frames, expected_frames);
```
The valid_frames calculation doesn't seem to take the actual feature length (features_length) into account. The p_features_length variable is initialized from features_length but is never used. This could lead to processing padded data as if it were real features, producing incorrect results.
You should probably use p_features_length to determine the number of valid frames, similar to the axera implementation.
```cpp
int32_t valid_frames = std::min<int32_t>(num_frames, expected_frames);
valid_frames = std::min<int32_t>(valid_frames,
                                 static_cast<int32_t>(p_features_length[0]));
```

```cpp
std::vector<Ort::Value> Forward(Ort::Value features,
                                Ort::Value features_length) {
```
The Forward method in the axera implementation (offline-fire-red-asr-ctc-model-axera.cc) uses a std::lock_guard to ensure thread safety. This implementation does not have a mutex. If a single OfflineFireRedAsrCtcModelAxcl instance is intended to be used by multiple threads concurrently, this Forward method is not thread-safe, as it modifies the state of the underlying model_ via SetInputTensorData and Run. Please consider adding a mutex to protect the Forward method if concurrent access is possible.
```cpp
class OfflineFireRedAsrCtcModelAxcl::Impl {
 public:
  explicit Impl(const OfflineModelConfig &config)
      : config_(config), allocator_{} {
    model_ = std::make_unique<AxclModel>(config_.fire_red_asr_ctc.model);

    Init();
  }

  template <typename Manager>
  Impl(Manager *mgr, const OfflineModelConfig &config)
      : config_(config), allocator_{} {
    auto buf = ReadFile(mgr, config_.fire_red_asr_ctc.model);
    model_ = std::make_unique<AxclModel>(buf.data(), buf.size());

    Init();
  }

  std::vector<Ort::Value> Forward(Ort::Value features,
                                  Ort::Value features_length) {
    auto features_shape = features.GetTensorTypeAndShapeInfo().GetShape();
    int32_t batch_size = features_shape[0];
    int32_t num_frames = features_shape[1];
    int32_t feat_dim = features_shape[2];

    const float *p_features = features.GetTensorData<float>();
    const int64_t *p_features_length = features_length.GetTensorData<int64_t>();

    if (batch_size != 1) {
      SHERPA_ONNX_LOGE("Only batch size 1 is supported by axcl. Given: %d",
                       batch_size);
      SHERPA_ONNX_EXIT(-1);
    }

    auto expected_shape = model_->TensorShape(model_->InputTensorNames()[0]);
    int32_t expected_frames = expected_shape[1];

    int32_t valid_frames = std::min<int32_t>(num_frames, expected_frames);
    std::vector<float> padded_features(expected_frames * feat_dim, 0.0f);
    std::copy(p_features, p_features + valid_frames * feat_dim,
              padded_features.begin());

    std::vector<int32_t> speech_length = {valid_frames};

    model_->SetInputTensorData(model_->InputTensorNames()[0],
                               padded_features.data(), padded_features.size());
    model_->SetInputTensorData(model_->InputTensorNames()[1],
                               speech_length.data(), speech_length.size());

    model_->Run();

    auto out_logits =
        model_->GetOutputTensorData(model_->OutputTensorNames()[0]);
    auto out_lengths =
        model_->GetOutputTensorData(model_->OutputTensorNames()[1]);

    auto out_shape = model_->TensorShape(model_->OutputTensorNames()[0]);
    int32_t out_frames = out_shape[1];
    int32_t vocab_size = out_shape[2];

    std::array<int64_t, 3> logits_shape = {1, out_frames, vocab_size};
    Ort::Value logits = Ort::Value::CreateTensor<float>(
        allocator_, logits_shape.data(), logits_shape.size());
    float *p_logits = logits.GetTensorMutableData<float>();
    std::copy(out_logits.begin(), out_logits.end(), p_logits);

    std::array<int64_t, 1> lengths_shape = {1};
    Ort::Value lengths = Ort::Value::CreateTensor<int64_t>(
        allocator_, lengths_shape.data(), lengths_shape.size());
    int64_t *p_lengths = lengths.GetTensorMutableData<int64_t>();
    p_lengths[0] = static_cast<int64_t>(out_lengths[0]);

    std::vector<Ort::Value> ans;
    ans.push_back(std::move(logits));
    ans.push_back(std::move(lengths));

    return ans;
  }

  int32_t VocabSize() const { return vocab_size_; }

  int32_t SubsamplingFactor() const { return subsampling_factor_; }

  OrtAllocator *Allocator() { return allocator_; }

  void NormalizeFeatures(float *features, int32_t num_frames,
                         int32_t feat_dim) const {
    if (static_cast<int32_t>(mean_.size()) != feat_dim) {
      SHERPA_ONNX_LOGE("Bad things happened");
      SHERPA_ONNX_LOGE("Wrong feat dim %d. Expect: %d", feat_dim,
                       static_cast<int32_t>(mean_.size()));
      SHERPA_ONNX_EXIT(-1);
    }

    using RowMajorMat =
        Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
    Eigen::Map<RowMajorMat> x(features, num_frames, feat_dim);

    Eigen::Map<const Eigen::RowVectorXf> mean(mean_.data(), feat_dim);
    Eigen::Map<const Eigen::RowVectorXf> inv_std(inv_stddev_.data(), feat_dim);
    x.array() =
        (x.array().rowwise() - mean.array()).rowwise() * inv_std.array();
  }

 private:
  void Init() {
    if (!model_->IsInitialized()) {
      SHERPA_ONNX_LOGE("Failed to initialize the model with '%s'",
                       config_.fire_red_asr_ctc.model.c_str());
      SHERPA_ONNX_EXIT(-1);
    }

    subsampling_factor_ = 4;

    auto shape = model_->TensorShape(model_->OutputTensorNames()[0]);
    vocab_size_ = shape.back();

    if (config_.debug) {
#if __OHOS__
      SHERPA_ONNX_LOGE("subsampling_factor: %{public}d", subsampling_factor_);
      SHERPA_ONNX_LOGE("vocab_size: %{public}d", vocab_size_);
#else
      SHERPA_ONNX_LOGE("subsampling_factor: %d", subsampling_factor_);
      SHERPA_ONNX_LOGE("vocab_size: %d", vocab_size_);
#endif
    }

    mean_ = {10.498912811279297, 10.948603630065918, 11.889163970947266,
             12.634881973266602, 13.397452354431152, 14.010934829711914,
             14.450813293457031, 14.649748802185059, 14.791581153869629,
             14.72234058380127,  14.802156448364258, 14.86101245880127,
             15.077230453491211, 15.26024341583252,  15.328754425048828,
             15.397353172302246, 15.395853996276855, 15.34103775024414,
             15.4662446975708,   15.271865844726562, 15.108253479003906,
             15.295886993408203, 15.07359504699707,  15.177886009216309,
             15.0756254196167,   15.154109001159668, 15.051127433776855,
             15.130733489990234, 15.090286254882812, 15.099433898925781,
             15.128166198730469, 15.123964309692383, 15.144022941589355,
             15.198014259338379, 15.251392364501953, 15.329950332641602,
             15.4017972946167,   15.45089340209961,  15.500616073608398,
             15.435726165771484, 15.51086139678955,  15.44755744934082,
             15.510979652404785, 15.491739273071289, 15.538031578063965,
             15.608367919921875, 15.694382667541504, 15.762181282043457,
             15.821470260620117, 15.901959419250488, 15.907241821289062,
             15.925711631774902, 15.952259063720703, 16.000732421875,
             16.030330657958984, 16.060592651367188, 16.09003448486328,
             16.100107192993164, 16.091808319091797, 16.062585830688477,
             16.05771255493164,  15.997002601623535, 15.946383476257324,
             15.865278244018555, 15.778145790100098, 15.67629623413086,
             15.569791793823242, 15.515979766845703, 15.472077369689941,
             15.423379898071289, 15.382068634033203, 15.345854759216309,
             15.301891326904297, 15.26984691619873,  15.165450096130371,
             15.004508972167969, 14.87544059753418,  14.564188003540039,
             14.031693458557129, 13.159259796142578};
    inv_stddev_ = {
        0.2522108852863312,  0.23741021752357483, 0.23185651004314423,
        0.23331022262573242, 0.23203925788402557, 0.22906658053398132,
        0.22519451379776,    0.22010253369808197, 0.21958276629447937,
        0.22198699414730072, 0.22393390536308289, 0.22370608150959015,
        0.22321352362632751, 0.2220749408006668,  0.22118520736694336,
        0.22136786580085754, 0.2220366895198822,  0.222808837890625,
        0.22362081706523895, 0.224283829331398,   0.22464141249656677,
        0.22580783069133759, 0.22700978815555573, 0.22852766513824463,
        0.22993983328342438, 0.23110738396644592, 0.23227347433567047,
        0.23270530998706818, 0.23330524563789368, 0.23406001925468445,
        0.23448589444160461, 0.23556077480316162, 0.23632891476154327,
        0.23703691363334656, 0.2377307415008545,  0.23786373436450958,
        0.2380155622959137,  0.23858875036239624, 0.23943373560905457,
        0.2399062216281891,  0.24094033241271973, 0.24173252284526825,
        0.24236661195755005, 0.2430112659931183,  0.24341483414173126,
        0.243240088224411,   0.24262498319149017, 0.24218837916851044,
        0.24165891110897064, 0.241318941116333,   0.2413933277130127,
        0.24139994382858276, 0.241432324051857,   0.24122384190559387,
        0.24079066514968872, 0.24032147228717804, 0.24016834795475006,
        0.24034327268600464, 0.24069449305534363, 0.24123424291610718,
        0.24136029183864594, 0.24150611460208893, 0.24179506301879883,
        0.24160170555114746, 0.24221885204315186, 0.24253536760807037,
        0.24262426793575287, 0.2428186535835266,  0.24223484098911285,
        0.24199971556663513, 0.24160003662109375, 0.24074721336364746,
        0.23965489864349365, 0.23850350081920624, 0.2359732687473297,
        0.23006057739257812, 0.22904986143112183, 0.22814501821994781,
        0.22893856465816498, 0.23093441128730774};
  }

 private:
  OfflineModelConfig config_;
  std::unique_ptr<AxclModel> model_;
  Ort::AllocatorWithDefaultOptions allocator_;

  int32_t vocab_size_ = 0;
  int32_t subsampling_factor_ = 0;

  std::vector<float> mean_;
  std::vector<float> inv_stddev_;
};
```
There is significant code duplication between this file and sherpa-onnx/csrc/axera/offline-fire-red-asr-ctc-model-axera.cc. Specifically, the NormalizeFeatures method and the hardcoded mean_ and inv_stddev_ vectors are identical.
To improve maintainability and adhere to the DRY (Don't Repeat Yourself) principle, please consider refactoring this duplicated logic into a common utility.
For example, you could create a new header file (e.g., offline-fire-red-asr-ctc-model-common.h) to define the constant mean_ and inv_stddev_ vectors and a shared NormalizeFeatures function. Both axcl and axera implementations could then use this common utility, reducing code duplication.
This is in line with the repository's general rule to move duplicated utility functions to a common file.
References

- Move duplicated utility functions, such as `Trim`, to a common utility file (e.g., `text-utils.h` and `text-utils.cc`) for reuse across the codebase.
Actionable comments posted: 4
🧹 Nitpick comments (6)
sherpa-onnx/csrc/axcl/offline-fire-red-asr-ctc-model-axcl.cc (2)
112-112: Const correctness mismatch. Same issue as AXERA: `Impl::Allocator()` is non-const but called from the const wrapper method.

Proposed fix:

```diff
-  OrtAllocator *Allocator() { return allocator_; }
+  OrtAllocator *Allocator() const { return allocator_; }
```
156-210: Duplicate normalization parameters. These `mean_` and `inv_stddev_` arrays are identical to the AXERA implementation. As noted for the AXERA file, consider extracting to a shared location.

sherpa-onnx/csrc/axera/offline-fire-red-asr-ctc-model-axera.cc (2)
164-164: Const correctness mismatch. `Impl::Allocator()` is non-const, but it's called from `OfflineFireRedAsrCtcModelAxera::Allocator() const` (lines 323-324). This compiles because `impl_` is a pointer (the pointer itself doesn't change), but for consistency with the const interface, consider making `Impl::Allocator()` const.

Proposed fix:

```diff
-  OrtAllocator *Allocator() { return allocator_; }
+  OrtAllocator *Allocator() const { return allocator_; }
```
225-279: Hardcoded normalization parameters duplicated across backends. The `mean_` and `inv_stddev_` arrays (80 elements each) are identical to those in the AXCL implementation (sherpa-onnx/csrc/axcl/offline-fire-red-asr-ctc-model-axcl.cc lines 156-210). Consider extracting these constants to a shared header or utility to avoid duplication and ensure they stay in sync.

sherpa-onnx/csrc/axcl/offline-fire-red-asr-ctc-model-axcl.h (1)
8-11: Minor: Unused includes. Same as the AXERA header: `<string>` and `<utility>` appear unused.

Proposed fix:

```diff
 #include <memory>
-#include <string>
-#include <utility>
 #include <vector>
```

sherpa-onnx/csrc/axera/offline-fire-red-asr-ctc-model-axera.h (1)
8-11: Minor: Unused includes. `<string>` and `<utility>` appear unused in this header. Consider removing them if they're not needed for the class interface.

Proposed fix:

```diff
 #include <memory>
-#include <string>
-#include <utility>
 #include <vector>
```
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@sherpa-onnx/csrc/axcl/offline-fire-red-asr-ctc-model-axcl.cc`:
- Around line 66-71: valid_frames currently uses std::min<int32_t>(num_frames,
expected_frames) but may still exceed the actual available frames in p_features
(p_features_length[0]); change the calculation to clamp against
p_features_length[0] as well (e.g., min of num_frames, expected_frames, and
p_features_length[0]) and then use that clamped valid_frames for the std::copy
into padded_features and for initializing speech_length so we never copy past
the provided p_features buffer.
- Around line 47-106: The Forward method is not thread-safe; add a class-level
std::mutex (e.g., mutex_) and include <mutex>, then acquire a lock at the start
of Forward (use std::lock_guard<std::mutex> lock(mutex_)) to serialize access to
model_ and any shared resources (allocator_, model_->SetInputTensorData,
model_->Run, model_->GetOutputTensorData, etc.); ensure the mutex is declared in
the same class that defines Forward so concurrent calls to Forward are
protected.
In `@sherpa-onnx/csrc/offline-ctc-model.cc`:
- Around line 221-230: The templated Create (Create<T>) has the same
silent-fallback: when config.provider == "axera" or "axcl" but the corresponding
build flags (SHERPA_ONNX_ENABLE_AXERA / SHERPA_ONNX_ENABLE_AXCL) are not set it
currently falls through to the default model; update the templated Create
implementation to mirror the fix applied to the non-templated Create: add `#else`
branches for each provider case that emit a clear error (e.g., throw
std::runtime_error or call the logger) stating that the requested provider
(config.provider) is not available in this build, and only return the
provider-specific instance inside the `#if` blocks
(std::make_unique<OfflineFireRedAsrCtcModelAxera> /
std::make_unique<OfflineFireRedAsrCtcModelAxcl>); this ensures Create<T> does
not silently return the default OfflineFireRedAsrCtcModel when a provider was
explicitly requested.
- Around line 142-151: The branch handling config.provider currently falls back
silently to OfflineFireRedAsrCtcModel when
SHERPA_ONNX_ENABLE_AXERA/SHERPA_ONNX_ENABLE_AXCL are not defined; change the
conditional blocks for the "axera" and "axcl" cases so that if the compile-time
flag is not enabled you emit a clear runtime error or warning and abort (e.g.,
throw std::runtime_error or LOG(FATAL)) instead of returning the CPU model;
specifically update the provider check that would return
OfflineFireRedAsrCtcModelAxera/OfflineFireRedAsrCtcModelAxcl to include an `#else`
that reports that config.provider == "axera" (or "axcl") is unavailable because
SHERPA_ONNX_ENABLE_AXERA/SHERPA_ONNX_ENABLE_AXCL is not defined, and only fall
back to OfflineFireRedAsrCtcModel for an explicit unknown/default provider case.
---
Nitpick comments:
In `@sherpa-onnx/csrc/axcl/offline-fire-red-asr-ctc-model-axcl.cc`:
- Line 112: The Impl::Allocator() accessor is non-const but is invoked from a
const wrapper; make its signature const-correct by changing or adding a
const-qualified overload so it can be called from const methods — e.g., add or
modify OrtAllocator *Allocator() const (or provide a const OrtAllocator*
Allocator() const and keep the non-const version) to return allocator_ and
satisfy const callers (update any declarations in the class definition for Impl
to match).
- Around line 156-210: mean_ and inv_stddev_ are duplicated from the AXERA
implementation; extract these arrays into a single shared constant and reference
them instead of duplicating. Create a shared header (e.g., declare
constexpr/std::array or static const std::vector named kMean and kInvStddev)
and a matching source if needed, then replace the inline initializers of mean_
and inv_stddev_ in offline-fire-red-asr-ctc-model-axcl.cc (and the AXERA file)
to copy or reference the shared constants (e.g., mean_ = shared::kMean;
inv_stddev_ = shared::kInvStddev;), updating includes to pull in the new header.
In `@sherpa-onnx/csrc/axcl/offline-fire-red-asr-ctc-model-axcl.h`:
- Around line 8-11: The header currently includes <string> and <utility> which
are unused; remove those two includes from the top of
offline-fire-red-asr-ctc-model-axcl.h so only required headers (<memory>,
<vector>, etc.) remain; ensure no compilation errors by confirming symbols in
this header (constructors, class declarations, or function prototypes
referencing std::string or std::pair) actually don't rely on those headers
before deletion.
In `@sherpa-onnx/csrc/axera/offline-fire-red-asr-ctc-model-axera.cc`:
- Line 164: Impl::Allocator() is non-const but is invoked by
OfflineFireRedAsrCtcModelAxera::Allocator() const; make Impl::Allocator() a
const method to match the const interface. Update the Impl class declaration and
its definition to change OrtAllocator *Allocator() to OrtAllocator *Allocator()
const (and adjust any const_casts/usages if present) so the const member
function can call it without casting.
- Around line 225-279: The mean_ and inv_stddev_ hardcoded arrays are
duplicated; extract them into a single shared constant (e.g., a header or small
utility) and reference that constant from both backends instead of redefining
the arrays. Create a shared symbol name (for example kFireRedMean and
kFireRedInvStddev or similar) with the same type/size (80 floats) in a new
header, include that header in both files, and replace the local mean_ and
inv_stddev_ definitions in offline-fire-red-asr-ctc-model-axera.cc (and the AXCL
file) with assignments or references to the shared symbols so both
implementations use the single source of truth.
In `@sherpa-onnx/csrc/axera/offline-fire-red-asr-ctc-model-axera.h`:
- Around line 8-11: The header currently includes unused headers <string> and
<utility>; remove those two includes from the top of the file (leave <memory>,
<vector> intact) unless the class/interface in this header (e.g., any
declarations related to the OfflineFireRedAsrCtcModelAxera type) actually
requires std::string or std::move/etc.; if those symbols are needed by function
signatures in this header, keep them or replace with forward declarations;
otherwise delete the unused includes to clean up dependencies.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 797afc21-73a7-47fe-ab17-d4a8ad59f2f2
📒 Files selected for processing (7)
- sherpa-onnx/csrc/CMakeLists.txt
- sherpa-onnx/csrc/axcl/offline-fire-red-asr-ctc-model-axcl.cc
- sherpa-onnx/csrc/axcl/offline-fire-red-asr-ctc-model-axcl.h
- sherpa-onnx/csrc/axera/offline-fire-red-asr-ctc-model-axera.cc
- sherpa-onnx/csrc/axera/offline-fire-red-asr-ctc-model-axera.h
- sherpa-onnx/csrc/offline-ctc-model.cc
- sherpa-onnx/csrc/offline-recognizer-impl.cc
Missing thread synchronization in Forward method.
The AXERA implementation (sherpa-onnx/csrc/axera/offline-fire-red-asr-ctc-model-axera.cc) uses a mutex to protect the Forward method, but this AXCL implementation has no synchronization. If AxclModel is not internally thread-safe (and hardware inference engines typically aren't), concurrent calls to Forward could cause data races or undefined behavior.
Consider adding mutex protection for consistency with the AXERA implementation:
Proposed fix

Add a mutex member and lock in Forward:

```diff
 private:
  OfflineModelConfig config_;
  std::unique_ptr<AxclModel> model_;
  Ort::AllocatorWithDefaultOptions allocator_;
+ std::mutex mutex_;
  int32_t vocab_size_ = 0;
  int32_t subsampling_factor_ = 0;
```

And in Forward:

```diff
 std::vector<Ort::Value> Forward(Ort::Value features,
                                 Ort::Value features_length) {
+  std::lock_guard<std::mutex> lock(mutex_);
+
   auto features_shape = features.GetTensorTypeAndShapeInfo().GetShape();
```

Don't forget to add `#include <mutex>` at the top.
🤖 Prompt for AI Agents

```
Verify each finding against the current code and only fix it if needed.
In sherpa-onnx/csrc/axcl/offline-fire-red-asr-ctc-model-axcl.cc around lines
47-106, the Forward method is not thread-safe; add a class-level std::mutex
(e.g., mutex_) and include <mutex>, then acquire a lock at the start of
Forward (use std::lock_guard<std::mutex> lock(mutex_)) to serialize access to
model_ and any shared resources (allocator_, model_->SetInputTensorData,
model_->Run, model_->GetOutputTensorData, etc.); ensure the mutex is declared
in the same class that defines Forward so concurrent calls to Forward are
protected.
```
```cpp
  int32_t valid_frames = std::min<int32_t>(num_frames, expected_frames);
  std::vector<float> padded_features(expected_frames * feat_dim, 0.0f);
  std::copy(p_features, p_features + valid_frames * feat_dim,
            padded_features.begin());

  std::vector<int32_t> speech_length = {valid_frames};
```
**Verify `valid_frames` doesn't exceed `p_features_length` bounds.**

The `valid_frames` calculation at line 66 clamps to `expected_frames`, but doesn't validate against `p_features_length[0]`. If `p_features_length[0]` is less than `num_frames`, you may be copying more data than intended. The AXERA implementation has the same pattern but explicitly incorporates `p_features_length[0]` in the calculation.
Consider clamping to the actual valid length:
Proposed fix

```diff
 int32_t valid_frames = std::min<int32_t>(num_frames, expected_frames);
+valid_frames = std::min<int32_t>(valid_frames,
+                                 static_cast<int32_t>(p_features_length[0]));
 std::vector<float> padded_features(expected_frames * feat_dim, 0.0f);
```

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```cpp
  int32_t valid_frames = std::min<int32_t>(num_frames, expected_frames);
  valid_frames = std::min<int32_t>(valid_frames,
                                   static_cast<int32_t>(p_features_length[0]));
  std::vector<float> padded_features(expected_frames * feat_dim, 0.0f);
  std::copy(p_features, p_features + valid_frames * feat_dim,
            padded_features.begin());

  std::vector<int32_t> speech_length = {valid_frames};
```
🤖 Prompt for AI Agents

```
Verify each finding against the current code and only fix it if needed.
In sherpa-onnx/csrc/axcl/offline-fire-red-asr-ctc-model-axcl.cc around lines
66-71, valid_frames currently uses std::min<int32_t>(num_frames,
expected_frames) but may still exceed the actual available frames in
p_features (p_features_length[0]); change the calculation to clamp against
p_features_length[0] as well (e.g., min of num_frames, expected_frames, and
p_features_length[0]) and then use that clamped valid_frames for the
std::copy into padded_features and for initializing speech_length so we never
copy past the provided p_features buffer.
```
```cpp
  if (config.provider == "axera") {
#if SHERPA_ONNX_ENABLE_AXERA
    return std::make_unique<OfflineFireRedAsrCtcModelAxera>(config);
#endif
  } else if (config.provider == "axcl") {
#if SHERPA_ONNX_ENABLE_AXCL
    return std::make_unique<OfflineFireRedAsrCtcModelAxcl>(config);
#endif
  }
  return std::make_unique<OfflineFireRedAsrCtcModel>(config);
```
**Silent fallback when AXERA/AXCL not enabled.**

When the user specifies `provider == "axera"` but `SHERPA_ONNX_ENABLE_AXERA` is not defined, the code silently falls through to create the default `OfflineFireRedAsrCtcModel`. This could lead to unexpected behavior where the user believes they're using AXERA but actually gets the CPU implementation.

Consider adding an error or warning when the requested provider isn't available:

Proposed fix

```diff
 } else if (!config.fire_red_asr_ctc.model.empty()) {
   if (config.provider == "axera") {
 #if SHERPA_ONNX_ENABLE_AXERA
     return std::make_unique<OfflineFireRedAsrCtcModelAxera>(config);
+#else
+    SHERPA_ONNX_LOGE(
+        "Provider 'axera' requested but SHERPA_ONNX_ENABLE_AXERA is not "
+        "enabled. Falling back to default implementation.");
 #endif
   } else if (config.provider == "axcl") {
 #if SHERPA_ONNX_ENABLE_AXCL
     return std::make_unique<OfflineFireRedAsrCtcModelAxcl>(config);
+#else
+    SHERPA_ONNX_LOGE(
+        "Provider 'axcl' requested but SHERPA_ONNX_ENABLE_AXCL is not "
+        "enabled. Falling back to default implementation.");
 #endif
   }
   return std::make_unique<OfflineFireRedAsrCtcModel>(config);
 }
```

🤖 Prompt for AI Agents

```
Verify each finding against the current code and only fix it if needed.
In sherpa-onnx/csrc/offline-ctc-model.cc around lines 142-151, the branch
handling config.provider currently falls back silently to
OfflineFireRedAsrCtcModel when SHERPA_ONNX_ENABLE_AXERA /
SHERPA_ONNX_ENABLE_AXCL are not defined; change the conditional blocks for the
"axera" and "axcl" cases so that if the compile-time flag is not enabled you
emit a clear runtime error or warning and abort (e.g., throw
std::runtime_error or LOG(FATAL)) instead of returning the CPU model;
specifically update the provider check that would return
OfflineFireRedAsrCtcModelAxera/OfflineFireRedAsrCtcModelAxcl to include an
#else that reports that config.provider == "axera" (or "axcl") is unavailable
because SHERPA_ONNX_ENABLE_AXERA/SHERPA_ONNX_ENABLE_AXCL is not defined, and
only fall back to OfflineFireRedAsrCtcModel for an explicit unknown/default
provider case.
```
```cpp
  if (config.provider == "axera") {
#if SHERPA_ONNX_ENABLE_AXERA
    return std::make_unique<OfflineFireRedAsrCtcModelAxera>(mgr, config);
#endif
  } else if (config.provider == "axcl") {
#if SHERPA_ONNX_ENABLE_AXCL
    return std::make_unique<OfflineFireRedAsrCtcModelAxcl>(mgr, config);
#endif
  }
  return std::make_unique<OfflineFireRedAsrCtcModel>(mgr, config);
```
**Same silent fallback issue in templated `Create`.**

The templated version has the same silent fallback behavior. If you address the issue above, apply the same fix here for consistency.
🤖 Prompt for AI Agents

```
Verify each finding against the current code and only fix it if needed.
In sherpa-onnx/csrc/offline-ctc-model.cc around lines 221-230, the templated
Create (Create<T>) has the same silent fallback: when config.provider ==
"axera" or "axcl" but the corresponding build flags
(SHERPA_ONNX_ENABLE_AXERA / SHERPA_ONNX_ENABLE_AXCL) are not set, it
currently falls through to the default model; update the templated Create
implementation to mirror the fix applied to the non-templated Create: add
#else branches for each provider case that emit a clear error (e.g., throw
std::runtime_error or call the logger) stating that the requested provider
(config.provider) is not available in this build, and only return the
provider-specific instance inside the #if blocks
(std::make_unique<OfflineFireRedAsrCtcModelAxera> /
std::make_unique<OfflineFireRedAsrCtcModelAxcl>); this ensures Create<T> does
not silently return the default OfflineFireRedAsrCtcModel when a provider was
explicitly requested.
```
Summary by CodeRabbit