FireRedASR2 supported axcl backend#3272
Conversation
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request expands the capabilities of the …

Highlights
No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID:
📒 Files selected for processing (3)
🚧 Files skipped from review as they are similar to previous changes (1)
📝 Walkthrough

Adds an AXCL-backed OfflineFireRedAsrCtcModel implementation, exposes AxclModel raw-output retrieval, and integrates the AXCL model into factories and recognizer creation paths with CMake and conditional compilation updates.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Factory
    participant OfflineModelAxcl
    participant AxclModel
    participant HostMemory
    Factory->>OfflineModelAxcl: Create(config) / Create(mgr, config)
    OfflineModelAxcl->>AxclModel: Load model (file or buffer) and Init()
    OfflineModelAxcl->>OfflineModelAxcl: NormalizeFeatures(features)
    OfflineModelAxcl->>AxclModel: SetInputTensors(padded_features, speech_length)
    OfflineModelAxcl->>AxclModel: Run()
    AxclModel-->>OfflineModelAxcl: GetOutputTensorDataRaw(logits_name)
    AxclModel-->>OfflineModelAxcl: GetOutputTensorDataRaw(lengths_name)
    OfflineModelAxcl->>HostMemory: Parse outputs -> create Ort::Value tensors
    OfflineModelAxcl-->>Factory: Return logits_tensor, lengths_tensor
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 3
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@sherpa-onnx/csrc/axcl/offline-fire-red-asr-ctc-model-axcl.cc`:
- Around line 47-71: The Forward method is incorrectly using features_shape[1]
(padded width) as the valid frame count and silently truncating long utterances;
instead read the authoritative valid-frame count from the features_length tensor
(use p_features_length[0]) and use that as valid_frames, then check against the
model window (expected_frames) and reject or handle oversized utterances rather
than silently clipping. Update Forward (function Forward, variables
p_features_length, expected_frames, padded_features) to set valid_frames =
static_cast<int32_t>(p_features_length[0]), validate that valid_frames <=
expected_frames and if not log an error and exit (or implement explicit chunking
logic) before copying/padding; otherwise proceed to copy valid_frames worth of
data into padded_features.
- Around line 73-99: After calling model_->Run() and retrieving outputs via
model_->GetOutputTensorData(model_->OutputTensorNames()[0]) and [1], validate
that out_logits and out_lengths are non-empty and that
TensorShape(model_->OutputTensorNames()[0]) yields expected dimensions before
creating Ort::Value tensors and using std::copy or indexing out_lengths[0]; if
any check fails, return or throw an error (or log and abort) so you don't
dereference empty vectors or copy zero elements into p_logits or read
p_lengths[0]. Ensure checks reference the existing symbols: SetInputTensorData,
Run, GetOutputTensorData, out_logits, out_lengths,
TensorShape(model_->OutputTensorNames()[0]), std::copy, and p_lengths[0].
In `@sherpa-onnx/csrc/axcl/offline-fire-red-asr-ctc-model-axcl.h`:
- Around line 19-38: The class OfflineFireRedAsrCtcModelAxcl must explicitly
disable batch processing to avoid multi-stream calls hitting Forward() with
batch_size != 1; add an override of SupportBatchProcessing() in the
OfflineFireRedAsrCtcModelAxcl class that returns false (i.e., bool
SupportBatchProcessing() const override { return false; }) so the runtime will
not route batched inputs to this backend; update the class declaration to
include this method override alongside the existing methods (e.g., Forward,
VocabSize, SubsamplingFactor, Allocator, NormalizeFeatures).
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: f27920d6-2061-4d1d-973c-9b497eb5dc4a
📒 Files selected for processing (5)
- sherpa-onnx/csrc/CMakeLists.txt
- sherpa-onnx/csrc/axcl/offline-fire-red-asr-ctc-model-axcl.cc
- sherpa-onnx/csrc/axcl/offline-fire-red-asr-ctc-model-axcl.h
- sherpa-onnx/csrc/offline-ctc-model.cc
- sherpa-onnx/csrc/offline-recognizer-impl.cc
```cpp
std::vector<Ort::Value> Forward(Ort::Value features,
                                Ort::Value features_length) {
  auto features_shape = features.GetTensorTypeAndShapeInfo().GetShape();
  int32_t batch_size = features_shape[0];
  int32_t num_frames = features_shape[1];
  int32_t feat_dim = features_shape[2];

  const float *p_features = features.GetTensorData<float>();
  const int64_t *p_features_length = features_length.GetTensorData<int64_t>();

  if (batch_size != 1) {
    SHERPA_ONNX_LOGE("Only batch size 1 is supported by axcl. Given: %d",
                     batch_size);
    SHERPA_ONNX_EXIT(-1);
  }

  auto expected_shape = model_->TensorShape(model_->InputTensorNames()[0]);
  int32_t expected_frames = expected_shape[1];

  int32_t valid_frames = std::min<int32_t>(num_frames, expected_frames);
  std::vector<float> padded_features(expected_frames * feat_dim, 0.0f);
  std::copy(p_features, p_features + valid_frames * feat_dim,
            padded_features.begin());

  std::vector<int32_t> speech_length = {valid_frames};
```
Use features_length here and don't silently clip long utterances.
features_shape[1] is the padded tensor width, not the authoritative valid-frame count. With std::min(num_frames, expected_frames), any utterance longer than the AXCL window is truncated and the tail audio is lost instead of being rejected or chunked.
Suggested fix

```diff
-  int32_t num_frames = features_shape[1];
   int32_t feat_dim = features_shape[2];
@@
-  int32_t valid_frames = std::min<int32_t>(num_frames, expected_frames);
+  int32_t valid_frames = static_cast<int32_t>(p_features_length[0]);
+  if (valid_frames > expected_frames) {
+    SHERPA_ONNX_LOGE(
+        "Input has %d valid frames, but the AXCL FireRedASR CTC model only "
+        "accepts %d frames.",
+        valid_frames, expected_frames);
+    SHERPA_ONNX_EXIT(-1);
+  }
   std::vector<float> padded_features(expected_frames * feat_dim, 0.0f);
   std::copy(p_features, p_features + valid_frames * feat_dim,
             padded_features.begin());
```
```cpp
class OfflineFireRedAsrCtcModelAxcl : public OfflineCtcModel {
 public:
  explicit OfflineFireRedAsrCtcModelAxcl(const OfflineModelConfig &config);

  template <typename Manager>
  OfflineFireRedAsrCtcModelAxcl(Manager *mgr, const OfflineModelConfig &config);

  ~OfflineFireRedAsrCtcModelAxcl() override;

  std::vector<Ort::Value> Forward(Ort::Value features,
                                  Ort::Value features_length) override;

  int32_t VocabSize() const override;

  int32_t SubsamplingFactor() const override;

  OrtAllocator *Allocator() const override;

  void NormalizeFeatures(float *features, int32_t num_frames,
                         int32_t feat_dim) const override;
```
Declare this AXCL model as non-batchable.
OfflineCtcModel defaults SupportBatchProcessing() to true, but this backend hard-fails in Forward() when batch_size != 1. That mismatch can route multi-stream decode into a path that is guaranteed to abort.
Suggested fix

```diff
 class OfflineFireRedAsrCtcModelAxcl : public OfflineCtcModel {
  public:
   explicit OfflineFireRedAsrCtcModelAxcl(const OfflineModelConfig &config);
@@
   int32_t VocabSize() const override;

   int32_t SubsamplingFactor() const override;
+
+  bool SupportBatchProcessing() const override { return false; }

   OrtAllocator *Allocator() const override;
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```cpp
class OfflineFireRedAsrCtcModelAxcl : public OfflineCtcModel {
 public:
  explicit OfflineFireRedAsrCtcModelAxcl(const OfflineModelConfig &config);

  template <typename Manager>
  OfflineFireRedAsrCtcModelAxcl(Manager *mgr, const OfflineModelConfig &config);

  ~OfflineFireRedAsrCtcModelAxcl() override;

  std::vector<Ort::Value> Forward(Ort::Value features,
                                  Ort::Value features_length) override;

  int32_t VocabSize() const override;

  int32_t SubsamplingFactor() const override;

  bool SupportBatchProcessing() const override { return false; }

  OrtAllocator *Allocator() const override;

  void NormalizeFeatures(float *features, int32_t num_frames,
                         int32_t feat_dim) const override;
```
Code Review
This pull request adds support for the FireRedASR2 CTC model with the AXCL backend. The changes include a new implementation of the model (offline-fire-red-asr-ctc-model-axcl.cc) and its integration into the build system and model factory functions.
My review focuses on the new implementation. I've identified a few areas for improvement regarding performance, code clarity, and maintainability, such as optimizing the initialization of constant data and removing unused code. Overall, the changes look good and are consistent with the goal of the pull request.
Note: Security Review did not run due to the size of the PR.
```cpp
  int32_t feat_dim = features_shape[2];

  const float *p_features = features.GetTensorData<float>();
  const int64_t *p_features_length = features_length.GetTensorData<int64_t>();
```

```cpp
  ans.push_back(std::move(logits));
  ans.push_back(std::move(lengths));
```

```cpp
void NormalizeFeatures(float *features, int32_t num_frames,
                       int32_t feat_dim) const {
  if (static_cast<int32_t>(mean_.size()) != feat_dim) {
    SHERPA_ONNX_LOGE("Bad things happened");
```
```cpp
  mean_ = {10.498912811279297, 10.948603630065918, 11.889163970947266,
           12.634881973266602, 13.397452354431152, 14.010934829711914,
           14.450813293457031, 14.649748802185059, 14.791581153869629,
           14.72234058380127,  14.802156448364258, 14.86101245880127,
           15.077230453491211, 15.26024341583252,  15.328754425048828,
           15.397353172302246, 15.395853996276855, 15.34103775024414,
           15.4662446975708,   15.271865844726562, 15.108253479003906,
           15.295886993408203, 15.07359504699707,  15.177886009216309,
           15.0756254196167,   15.154109001159668, 15.051127433776855,
           15.130733489990234, 15.090286254882812, 15.099433898925781,
           15.128166198730469, 15.123964309692383, 15.144022941589355,
           15.198014259338379, 15.251392364501953, 15.329950332641602,
           15.4017972946167,   15.45089340209961,  15.500616073608398,
           15.435726165771484, 15.51086139678955,  15.44755744934082,
           15.510979652404785, 15.491739273071289, 15.538031578063965,
           15.608367919921875, 15.694382667541504, 15.762181282043457,
           15.821470260620117, 15.901959419250488, 15.907241821289062,
           15.925711631774902, 15.952259063720703, 16.000732421875,
           16.030330657958984, 16.060592651367188, 16.09003448486328,
           16.100107192993164, 16.091808319091797, 16.062585830688477,
           16.05771255493164,  15.997002601623535, 15.946383476257324,
           15.865278244018555, 15.778145790100098, 15.67629623413086,
           15.569791793823242, 15.515979766845703, 15.472077369689941,
           15.423379898071289, 15.382068634033203, 15.345854759216309,
           15.301891326904297, 15.26984691619873,  15.165450096130371,
           15.004508972167969, 14.87544059753418,  14.564188003540039,
           14.031693458557129, 13.159259796142578};

  inv_stddev_ = {
      0.2522108852863312,  0.23741021752357483, 0.23185651004314423,
      0.23331022262573242, 0.23203925788402557, 0.22906658053398132,
      0.22519451379776,    0.22010253369808197, 0.21958276629447937,
      0.22198699414730072, 0.22393390536308289, 0.22370608150959015,
      0.22321352362632751, 0.2220749408006668,  0.22118520736694336,
      0.22136786580085754, 0.2220366895198822,  0.222808837890625,
      0.22362081706523895, 0.224283829331398,   0.22464141249656677,
      0.22580783069133759, 0.22700978815555573, 0.22852766513824463,
      0.22993983328342438, 0.23110738396644592, 0.23227347433567047,
      0.23270530998706818, 0.23330524563789368, 0.23406001925468445,
      0.23448589444160461, 0.23556077480316162, 0.23632891476154327,
      0.23703691363334656, 0.2377307415008545,  0.23786373436450958,
      0.2380155622959137,  0.23858875036239624, 0.23943373560905457,
      0.2399062216281891,  0.24094033241271973, 0.24173252284526825,
      0.24236661195755005, 0.2430112659931183,  0.24341483414173126,
      0.243240088224411,   0.24262498319149017, 0.24218837916851044,
      0.24165891110897064, 0.241318941116333,   0.2413933277130127,
      0.24139994382858276, 0.241432324051857,   0.24122384190559387,
      0.24079066514968872, 0.24032147228717804, 0.24016834795475006,
      0.24034327268600464, 0.24069449305534363, 0.24123424291610718,
      0.24136029183864594, 0.24150611460208893, 0.24179506301879883,
      0.24160170555114746, 0.24221885204315186, 0.24253536760807037,
      0.24262426793575287, 0.2428186535835266,  0.24223484098911285,
      0.24199971556663513, 0.24160003662109375, 0.24074721336364746,
      0.23965489864349365, 0.23850350081920624, 0.2359732687473297,
      0.23006057739257812, 0.22904986143112183, 0.22814501821994781,
      0.22893856465816498, 0.23093441128730774};
```
These large hardcoded vectors for mean_ and inv_stddev_ are inefficiently initialized on every object creation and make the code hard to read.
Consider defining them as static const std::array at file scope and assigning them to the member vectors in Init(). This will improve both performance and maintainability.
For example:

```cpp
// At file scope
namespace {
static const std::array<float, 80> kFireRedAsrCtcMean = {{
    // ... values
}};
}  // namespace

// In Init()
mean_.assign(kFireRedAsrCtcMean.begin(), kFireRedAsrCtcMean.end());
```

A similar change should be applied to inv_stddev_.
…CTC model to use it