[WIP] add CTC prefix beam search / hotwords #1439
pkufool wants to merge 8 commits into k2-fsa:master from
Conversation
May I ask when this pull request is planned to be merged?
Hello. I trained a CR-CTC model and decoded with the streaming CTC model, and got token repetition (e.g. ref: 안녕하세요 / hyp: 안녕녕하세요). So I really need online prefix beam search. Do you have any plans to release online CTC prefix beam search? Thank you!
📝 Walkthrough

Adds offline and online CTC prefix-beam-search decoders, extends decoder interfaces to accept optional stream contexts, adds CTC-aware hypothesis scoring (blank/non-blank probs), and wires hotword / ContextGraph support into offline and online recognizers and example CLIs.
Sequence Diagram

    sequenceDiagram
        participant Client
        participant Recognizer as Recognizer\n(Offline/Online)
        participant Decoder as CTCPrefixBeamDecoder
        participant StepWorker
        participant ContextGraph
        participant Hypotheses
        Client->>Recognizer: decode(log_probs, log_probs_length, hotwords?)
        Recognizer->>Decoder: Decode(..., ss[], n)
        loop per time frame (t)
            Decoder->>StepWorker: top-k candidates from frame t
            loop per hypothesis
                StepWorker->>Hypotheses: evaluate blank/repeat/extend cases
                alt ContextGraph present
                    StepWorker->>ContextGraph: ForwardOneStep(token, ctx_state)
                    ContextGraph-->>StepWorker: lm_logprob, new_ctx_state
                    StepWorker->>Hypotheses: add/update hypothesis (update lm and ctx_state)
                else
                    StepWorker->>Hypotheses: add/update hypothesis (CTC probs)
                end
            end
            Decoder->>Hypotheses: prune/select active paths (max_active_paths)
        end
        Decoder->>Hypotheses: GetMostProbable(use_ctc=true)
        Hypotheses-->>Decoder: best hypothesis
        Decoder-->>Recognizer: results
        Recognizer-->>Client: return results
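For readers new to the algorithm the diagram summarizes, below is a minimal, self-contained sketch of the per-frame CTC prefix-beam-search recurrence: each prefix keeps separate blank-ending and non-blank-ending scores, every frame either keeps a prefix (blank or repeated token) or extends it, and the beam is pruned to max_active_paths. All names here (Prefix, Scores, log_add, the toy posteriors) are illustrative and are not the sherpa-onnx classes touched by this PR.

```cpp
// Minimal CTC prefix beam search over toy log-posteriors (illustrative only).
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <limits>
#include <map>
#include <utility>
#include <vector>

constexpr double kNegInf = -std::numeric_limits<double>::infinity();

static double log_add(double a, double b) {
  if (a == kNegInf) return b;
  if (b == kNegInf) return a;
  double m = std::max(a, b);
  return m + std::log1p(std::exp(std::min(a, b) - m));
}

int main() {
  constexpr int kBlank = 0;
  // log_probs[t][v]: 3 frames x 3 tokens (token 0 is the blank).
  std::vector<std::vector<double>> log_probs = {
      {std::log(0.6), std::log(0.3), std::log(0.1)},
      {std::log(0.2), std::log(0.7), std::log(0.1)},
      {std::log(0.5), std::log(0.4), std::log(0.1)},
  };

  using Prefix = std::vector<int>;
  struct Scores {
    double b = kNegInf;   // score of paths ending in blank
    double nb = kNegInf;  // score of paths ending in a non-blank token
  };

  const int max_active_paths = 4;
  std::map<Prefix, Scores> beam;
  beam[{}].b = 0.0;  // empty prefix with log-probability 0 (= prob 1)

  for (const auto &frame : log_probs) {
    std::map<Prefix, Scores> next;
    for (const auto &[prefix, s] : beam) {
      double total = log_add(s.b, s.nb);
      for (int v = 0; v < static_cast<int>(frame.size()); ++v) {
        double lp = frame[v];
        if (v == kBlank) {
          // Case 0: prefix unchanged, now ends in blank.
          next[prefix].b = log_add(next[prefix].b, total + lp);
        } else if (!prefix.empty() && prefix.back() == v) {
          // Case 1: repeated token collapses, prefix unchanged.
          next[prefix].nb = log_add(next[prefix].nb, s.nb + lp);
          // Case 2: blank in between, prefix grows by one more copy.
          Prefix ext = prefix;
          ext.push_back(v);
          next[ext].nb = log_add(next[ext].nb, s.b + lp);
        } else {
          // Case 3: a new token extends the prefix.
          Prefix ext = prefix;
          ext.push_back(v);
          next[ext].nb = log_add(next[ext].nb, total + lp);
        }
      }
    }
    // Prune to the best max_active_paths prefixes by combined score.
    std::vector<std::pair<double, Prefix>> ranked;
    for (const auto &[prefix, s] : next) {
      ranked.push_back({log_add(s.b, s.nb), prefix});
    }
    std::sort(ranked.rbegin(), ranked.rend());
    if (static_cast<int>(ranked.size()) > max_active_paths) {
      ranked.resize(max_active_paths);
    }
    beam.clear();
    for (const auto &[score, prefix] : ranked) beam[prefix] = next[prefix];
  }

  for (const auto &[prefix, s] : beam) {
    std::printf("prefix length %zu, score %.4f\n", prefix.size(),
                log_add(s.b, s.nb));
  }
  return 0;
}
```

The three branches mirror the blank/repeat/extend cases shown in the diagram above; the real decoders additionally track timestamps, LM scores, and ContextGraph state per hypothesis.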
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
sherpa-onnx/csrc/offline-recognizer.cc (1)
84-89: ⚠️ Potential issue | 🟡 Minor
Update the hotwords validation error message to match the new allowed method.

Lines 87-89 still say only `modified_beam_search` is valid, but lines 84-85 now also allow `prefix_beam_search`. This will mislead users during config errors.

Proposed fix

 if (!hotwords_file.empty() && (decoding_method != "modified_beam_search" &&
                                decoding_method != "prefix_beam_search")) {
   SHERPA_ONNX_LOGE(
-      "Please use --decoding-method=modified_beam_search if you"
+      "Please use --decoding-method=modified_beam_search or "
+      "--decoding-method=prefix_beam_search if you"
       " provide --hotwords-file. Given --decoding-method='%s'",
       decoding_method.c_str());
   return false;
 }

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@sherpa-onnx/csrc/offline-recognizer.cc` around lines 84 - 89, The error message for hotwords validation is out of sync with the allowed decoding methods: update the SHERPA_ONNX_LOGE message in offline-recognizer.cc (the block that checks hotwords_file and decoding_method) to mention both allowed methods ("modified_beam_search" and "prefix_beam_search") and include the actual provided decoding_method via decoding_method.c_str() as it already does; adjust the wording so it instructs users to use either --decoding-method=modified_beam_search or --decoding-method=prefix_beam_search when supplying --hotwords-file.
🧹 Nitpick comments (3)
sherpa-onnx/csrc/hypothesis.h (1)
53-57: Minor inconsistency: using `float` infinity for a `double` variable.

`log_prob_nb` is declared as `double` but initialized with `std::numeric_limits<float>::infinity()`. While this works (float infinity converts to double infinity), using `std::numeric_limits<double>::infinity()` would be more consistent with the type declaration.

♻️ Suggested fix for type consistency

   // The total score of ys which ends with non blank token in log space
-  double log_prob_nb = -std::numeric_limits<float>::infinity();
+  double log_prob_nb = -std::numeric_limits<double>::infinity();

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@sherpa-onnx/csrc/hypothesis.h` around lines 53 - 57, the variable log_prob_nb in hypothesis.h is declared as double but initialized using std::numeric_limits<float>::infinity(); change the initializer to use std::numeric_limits<double>::infinity() to match the declared type (locate the log_prob_nb declaration and replace the float infinity with double infinity).

sherpa-onnx/csrc/offline-ctc-prefix-beam-search-decoder.cc (1)
113-117: Consider removing or gating the commented debug code.

Debug logging code is commented out. Consider removing it before merging or wrapping it in a debug preprocessor macro if needed for future development.

🧹 Suggested cleanup

     cur[b] = StepWorker(p_log_probs, cur[b], blank_id_, vocab_size,
                         max_active_paths_, context_graphs[b].get());
-    // for (auto &x : cur[b]) {
-    //   SHERPA_ONNX_LOGE("step : %d, key : %s, ac : %f, lm : %f", t,
-    //                    x.Key().c_str(), x.LogProb(true), x.lm_log_prob);
-    // }
-    // SHERPA_ONNX_LOGE("\n");
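If the logging is worth keeping, one option from the comment above is a compile-time gate. A minimal sketch of the pattern; SHERPA_ONNX_DEBUG_BEAM_SEARCH and DEBUG_LOG are hypothetical names used for illustration, not existing sherpa-onnx macros:

```cpp
// Sketch of gating per-step debug output behind a compile-time flag.
#include <cstdio>

#ifdef SHERPA_ONNX_DEBUG_BEAM_SEARCH
#define DEBUG_LOG(...) std::fprintf(stderr, __VA_ARGS__)
#else
#define DEBUG_LOG(...)  // compiled out unless the flag is defined
#endif

int main() {
  DEBUG_LOG("step : %d, ac : %f, lm : %f\n", 0, -1.5f, 0.0f);
  std::printf("decoding continues without per-step debug output\n");
  return 0;
}
```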
🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@sherpa-onnx/csrc/offline-ctc-prefix-beam-search-decoder.cc` around lines 113 - 117, remove or gate the commented-out debug logging in offline-ctc-prefix-beam-search-decoder.cc: either delete the block that iterates over cur[b] and calls SHERPA_ONNX_LOGE for Key(), LogProb(true) and lm_log_prob, or wrap it with a compile-time macro or runtime log-level check (e.g., DEBUG_BEAM_SEARCH or use the existing logging verbosity) so the snippet around cur[b], t, x.Key(), x.LogProb(true), and x.lm_log_prob is only included when debug logging is enabled.

sherpa-onnx/csrc/offline-recognizer-ctc-impl.h (1)
399-417: Use `SHERPA_ONNX_EXIT(-1)` for consistency with the rest of the codebase.

Line 406 uses `exit(-1)` directly, while line 230 uses `SHERPA_ONNX_EXIT(-1)`. The macro is used consistently elsewhere for controlled exit handling.

♻️ Suggested fix

   if (!is) {
     SHERPA_ONNX_LOGE("Open hotwords file failed: %s",
                      config_.hotwords_file.c_str());
-    exit(-1);
+    SHERPA_ONNX_EXIT(-1);
   }

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@sherpa-onnx/csrc/offline-recognizer-ctc-impl.h` around lines 399 - 417, In InitHotwords(), replace the direct call to exit(-1) with the project exit macro for consistency: call SHERPA_ONNX_EXIT(-1) instead of exit(-1) in the error branch that detects failure to open config_.hotwords_file; update the code around the if (!is) block in InitHotwords to use SHERPA_ONNX_EXIT so it matches other exits (e.g., the one referenced near line 230) and preserves controlled shutdown behavior.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: bfba1fc1-640a-4181-b394-42ab0c71a974
📒 Files selected for processing (12)
- sherpa-onnx/csrc/CMakeLists.txt
- sherpa-onnx/csrc/hypothesis.cc
- sherpa-onnx/csrc/hypothesis.h
- sherpa-onnx/csrc/offline-ctc-decoder.h
- sherpa-onnx/csrc/offline-ctc-fst-decoder.cc
- sherpa-onnx/csrc/offline-ctc-fst-decoder.h
- sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.cc
- sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h
- sherpa-onnx/csrc/offline-ctc-prefix-beam-search-decoder.cc
- sherpa-onnx/csrc/offline-ctc-prefix-beam-search-decoder.h
- sherpa-onnx/csrc/offline-recognizer-ctc-impl.h
- sherpa-onnx/csrc/offline-recognizer.cc
#if __ANDROID_API__ >= 9
  void InitHotwords(AAssetManager *mgr) {
    // each line in hotwords_file contains space-separated words

    auto buf = ReadFile(mgr, config_.hotwords_file);

    std::istringstream is(std::string(buf.begin(), buf.end()));

    if (!is) {
      SHERPA_ONNX_LOGE("Open hotwords file failed: %s",
                       config_.hotwords_file.c_str());
      exit(-1);
    }

    if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_,
                        bpe_encoder_.get(), &hotwords_, &boost_scores_)) {
      SHERPA_ONNX_LOGE(
          "Failed to encode some hotwords, skip them already, see logs above "
          "for details.");
    }
    hotwords_graph_ = std::make_shared<ContextGraph>(
        hotwords_, config_.hotwords_score, boost_scores_);
  }
#endif
Dead code: istringstream state check is always true after construction.
The !is check on line 427 is unreachable. An istringstream constructed from a string is always in a valid state. If ReadFile fails to read the file, you should check the return value of ReadFile instead.
Also, use SHERPA_ONNX_EXIT(-1) for consistency.
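As a quick sanity check of that claim (not part of the PR), a stream constructed from a string, even an empty one, reports a good state, so `!is` can never fire right after construction:

```cpp
// Demonstration that std::istringstream is valid after construction.
#include <cassert>
#include <sstream>
#include <string>

int main() {
  std::istringstream is(std::string{});  // built from an empty string
  assert(static_cast<bool>(is));         // holds: the stream is in a good state
  return 0;
}
```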
🐛 Suggested fix

 #if __ANDROID_API__ >= 9
   void InitHotwords(AAssetManager *mgr) {
     // each line in hotwords_file contains space-separated words
     auto buf = ReadFile(mgr, config_.hotwords_file);
-
-    std::istringstream is(std::string(buf.begin(), buf.end()));
-
-    if (!is) {
+    if (buf.empty()) {
       SHERPA_ONNX_LOGE("Open hotwords file failed: %s",
                        config_.hotwords_file.c_str());
-      exit(-1);
+      SHERPA_ONNX_EXIT(-1);
     }
+
+    std::istringstream is(std::string(buf.begin(), buf.end()));

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
#if __ANDROID_API__ >= 9
  void InitHotwords(AAssetManager *mgr) {
    // each line in hotwords_file contains space-separated words
    auto buf = ReadFile(mgr, config_.hotwords_file);
    if (buf.empty()) {
      SHERPA_ONNX_LOGE("Open hotwords file failed: %s",
                       config_.hotwords_file.c_str());
      SHERPA_ONNX_EXIT(-1);
    }

    std::istringstream is(std::string(buf.begin(), buf.end()));
    if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_,
                        bpe_encoder_.get(), &hotwords_, &boost_scores_)) {
      SHERPA_ONNX_LOGE(
          "Failed to encode some hotwords, skip them already, see logs above "
          "for details.");
    }
    hotwords_graph_ = std::make_shared<ContextGraph>(
        hotwords_, config_.hotwords_score, boost_scores_);
  }
#endif
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@sherpa-onnx/csrc/offline-recognizer-ctc-impl.h` around lines 419 - 442, The
istringstream validity check in InitHotwords is dead code; instead verify
ReadFile succeeded by checking the returned buffer (auto buf from ReadFile)
before constructing std::istringstream. If buf is empty or ReadFile indicates
failure, log the error using SHERPA_ONNX_LOGE with config_.hotwords_file and
call SHERPA_ONNX_EXIT(-1). Only after confirming buf contains data, create
std::istringstream, then call EncodeHotwords(modeling_unit via
config_.model_config.modeling_unit, symbol_table_, bpe_encoder_.get(),
&hotwords_, &boost_scores_) and proceed to create hotwords_graph_ as currently
done.
Agent-Logs-Url: https://github.com/pkufool/sherpa-onnx/sessions/f9574840-2100-424c-ac19-017ec182a0df Co-authored-by: pkufool <11765074+pkufool@users.noreply.github.com>
…search-python Implementing offline version of CTC prefix beam search in Python
🧹 Nitpick comments (1)
sherpa-onnx/python/sherpa_onnx/offline_recognizer.py (1)
1741-1806: Changes look good; consider fixing the pre-existing docstring URL.

The `max_active_paths` addition follows the same correct pattern. However, note that the docstring references the whisper models documentation (line 1752) instead of the wenet models. This is a pre-existing issue, not introduced by this PR, but could be addressed as a quick fix while you're here.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@sherpa-onnx/python/sherpa_onnx/offline_recognizer.py` around lines 1741 - 1806, The docstring in offline_recognizer.py (the classmethod that constructs model_config with OfflineWenetCtcModelConfig / OfflineRecognizerConfig) incorrectly references the Whisper pretrained models URL and wording; update that docstring to point to the appropriate WeNet/Sherpa ONNX wenet pretrained models page and replace occurrences of "whisper" with "wenet" (or the correct model family name) so the link and description match the wenet_ctc model being constructed.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 3111ab47-b381-425a-85c6-dacb477c9a34
📒 Files selected for processing (2)
- python-api-examples/offline-decode-files.py
- sherpa-onnx/python/sherpa_onnx/offline_recognizer.py
✅ Files skipped from review due to trivial changes (1)
- python-api-examples/offline-decode-files.py
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
python-api-examples/online-decode-files.py (1)
171-178: ⚠️ Potential issue | 🟡 Minor
Validate `prefix_beam_search` against the selected model.

This option is now advertised globally, but the transducer branch still passes `args.decoding_method` through unchanged. `--encoder/--decoder/--joiner --decoding-method=prefix_beam_search` will fail late instead of being rejected up front in the example script.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@python-api-examples/online-decode-files.py` around lines 171 - 178, The example script currently accepts "--decoding-method=prefix_beam_search" globally but does not validate it for transducer models; update the argument validation after parsing (using parser/args) to detect when a transducer model is being used (e.g., presence of args.encoder/args.decoder/args.joiner or whatever branch handles the transducer) and raise a clear error or exit if args.decoding_method == "prefix_beam_search" while the transducer branch will be used; ensure other branches (CTC, non-transducer) still allow prefix_beam_search and keep the transducer branch logic that previously passed args.decoding_method unchanged except for this upfront validation.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: f929b52d-b6fe-4ee9-8ca3-99fecb1cbce5
📒 Files selected for processing (10)
- c-api-examples/decode-file-c-api.c
- python-api-examples/online-decode-files.py
- sherpa-onnx/c-api/c-api.h
- sherpa-onnx/csrc/CMakeLists.txt
- sherpa-onnx/csrc/online-ctc-decoder.h
- sherpa-onnx/csrc/online-ctc-prefix-beam-search-decoder.cc
- sherpa-onnx/csrc/online-ctc-prefix-beam-search-decoder.h
- sherpa-onnx/csrc/online-recognizer-ctc-impl.h
- sherpa-onnx/csrc/online-recognizer.cc
- sherpa-onnx/python/sherpa_onnx/online_recognizer.py
✅ Files skipped from review due to trivial changes (2)
- c-api-examples/decode-file-c-api.c
- sherpa-onnx/c-api/c-api.h
🚧 Files skipped from review as they are similar to previous changes (1)
- sherpa-onnx/csrc/CMakeLists.txt
  Hypotheses next_hyps;
  for (auto &hyp : hyps) {
    for (auto k : topk) {
      Hypothesis new_hyp = hyp;
      int32_t new_token = k;
      float log_prob = p_log_probs[k];
      bool update_prefix = false;
      if (new_token == blank_id) {
        // Case 0: *a + ε => *a
        //         *aε + ε => *a
        // Prefix does not change, update log_prob of blank
        new_hyp.log_prob_nb = -std::numeric_limits<float>::infinity();
        new_hyp.log_prob_b = hyp.LogProb(true) + log_prob;
        new_hyp.num_trailing_blanks = hyp.num_trailing_blanks + 1;
        next_hyps.Add(std::move(new_hyp));
      } else if (hyp.ys.size() > 0 && hyp.ys.back() == new_token) {
        // Case 1: *a + a => *a
        // Prefix does not change, update log_prob of non_blank
        new_hyp.log_prob_nb = hyp.log_prob_nb + log_prob;
        new_hyp.log_prob_b = -std::numeric_limits<float>::infinity();
        new_hyp.num_trailing_blanks = 0;
        next_hyps.Add(std::move(new_hyp));

        // Case 2: *aε + a => *aa
        // Prefix changes, update log_prob of blank
        new_hyp = hyp;
        new_hyp.ys.push_back(new_token);
        new_hyp.log_prob_nb = hyp.log_prob_b + log_prob;
        new_hyp.log_prob_b = -std::numeric_limits<float>::infinity();
        new_hyp.num_trailing_blanks = 0;
        update_prefix = true;
      } else {
        // Case 3: *a + b => *ab, *aε + b => *ab
        // Prefix changes, update log_prob of non_blank
        new_hyp.ys.push_back(new_token);
        new_hyp.log_prob_nb = hyp.LogProb(true) + log_prob;
        new_hyp.log_prob_b = -std::numeric_limits<float>::infinity();
        new_hyp.num_trailing_blanks = 0;
        update_prefix = true;
      }

      if (update_prefix) {
        float lm_log_prob = hyp.lm_log_prob;
        if (context_graph != nullptr && hyp.context_state != nullptr) {
          auto context_res =
              context_graph->ForwardOneStep(hyp.context_state, new_token);
          lm_log_prob = lm_log_prob + std::get<0>(context_res);
          new_hyp.context_state = std::get<1>(context_res);
        }
        new_hyp.lm_log_prob = lm_log_prob;
        next_hyps.Add(std::move(new_hyp));
      }
    }
  }
  return next_hyps.GetTopK(max_active_paths, false, true);
Use CTC-aware hypothesis merging here.
StepWorker() updates log_prob_b/log_prob_nb, but every next_hyps.Add(...) call still uses the default non-CTC merge path. That breaks prefix merging whenever the same token sequence is reached through different blank/non-blank transitions.
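For background on what a CTC-aware merge has to do (not part of the suggested diff): when two candidates share the same token sequence, their blank-ending and non-blank-ending scores should be combined with log-sum-exp rather than keeping only the better one. A minimal sketch of that rule, using made-up names (Hyp, LogAdd, MergeCtc) rather than the actual Hypotheses API:

```cpp
// Illustration of a CTC-aware merge of two hypotheses with the same prefix.
#include <algorithm>
#include <cmath>
#include <cstdio>

struct Hyp {
  double log_prob_b;   // score of paths ending in blank
  double log_prob_nb;  // score of paths ending in a non-blank token
};

static double LogAdd(double a, double b) {
  double m = std::max(a, b);
  if (std::isinf(m)) return m;  // both -inf: nothing to add
  return m + std::log1p(std::exp(std::min(a, b) - m));
}

// Merge `src` into `dst`, assuming both represent the same token sequence.
static void MergeCtc(Hyp *dst, const Hyp &src) {
  dst->log_prob_b = LogAdd(dst->log_prob_b, src.log_prob_b);
  dst->log_prob_nb = LogAdd(dst->log_prob_nb, src.log_prob_nb);
}

int main() {
  Hyp a{std::log(0.20), std::log(0.10)};
  Hyp b{std::log(0.05), std::log(0.30)};
  MergeCtc(&a, b);  // a.log_prob_b == log(0.25), a.log_prob_nb == log(0.40)
  std::printf("merged: b=%.4f nb=%.4f\n", a.log_prob_b, a.log_prob_nb);
  return 0;
}
```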
Suggested fix

-        next_hyps.Add(std::move(new_hyp));
+        next_hyps.Add(std::move(new_hyp), /*use_ctc=*/true);
 ...
-        next_hyps.Add(std::move(new_hyp));
+        next_hyps.Add(std::move(new_hyp), /*use_ctc=*/true);
 ...
-        next_hyps.Add(std::move(new_hyp));
+        next_hyps.Add(std::move(new_hyp), /*use_ctc=*/true);

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
  Hypotheses next_hyps;
  for (auto &hyp : hyps) {
    for (auto k : topk) {
      Hypothesis new_hyp = hyp;
      int32_t new_token = k;
      float log_prob = p_log_probs[k];
      bool update_prefix = false;
      if (new_token == blank_id) {
        // Case 0: *a + ε => *a
        //         *aε + ε => *a
        // Prefix does not change, update log_prob of blank
        new_hyp.log_prob_nb = -std::numeric_limits<float>::infinity();
        new_hyp.log_prob_b = hyp.LogProb(true) + log_prob;
        new_hyp.num_trailing_blanks = hyp.num_trailing_blanks + 1;
        next_hyps.Add(std::move(new_hyp), /*use_ctc=*/true);
      } else if (hyp.ys.size() > 0 && hyp.ys.back() == new_token) {
        // Case 1: *a + a => *a
        // Prefix does not change, update log_prob of non_blank
        new_hyp.log_prob_nb = hyp.log_prob_nb + log_prob;
        new_hyp.log_prob_b = -std::numeric_limits<float>::infinity();
        new_hyp.num_trailing_blanks = 0;
        next_hyps.Add(std::move(new_hyp), /*use_ctc=*/true);
        // Case 2: *aε + a => *aa
        // Prefix changes, update log_prob of blank
        new_hyp = hyp;
        new_hyp.ys.push_back(new_token);
        new_hyp.log_prob_nb = hyp.log_prob_b + log_prob;
        new_hyp.log_prob_b = -std::numeric_limits<float>::infinity();
        new_hyp.num_trailing_blanks = 0;
        update_prefix = true;
      } else {
        // Case 3: *a + b => *ab, *aε + b => *ab
        // Prefix changes, update log_prob of non_blank
        new_hyp.ys.push_back(new_token);
        new_hyp.log_prob_nb = hyp.LogProb(true) + log_prob;
        new_hyp.log_prob_b = -std::numeric_limits<float>::infinity();
        new_hyp.num_trailing_blanks = 0;
        update_prefix = true;
      }
      if (update_prefix) {
        float lm_log_prob = hyp.lm_log_prob;
        if (context_graph != nullptr && hyp.context_state != nullptr) {
          auto context_res =
              context_graph->ForwardOneStep(hyp.context_state, new_token);
          lm_log_prob = lm_log_prob + std::get<0>(context_res);
          new_hyp.context_state = std::get<1>(context_res);
        }
        new_hyp.lm_log_prob = lm_log_prob;
        next_hyps.Add(std::move(new_hyp), /*use_ctc=*/true);
      }
    }
  }
  return next_hyps.GetTopK(max_active_paths, false, true);
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@sherpa-onnx/csrc/online-ctc-prefix-beam-search-decoder.cc` around lines 24 -
78, The loop creates new Hypothesis objects and updates log_prob_b/log_prob_nb
but calls next_hyps.Add(...) which uses the default (non-CTC) merge behavior;
change those Add calls to the CTC-aware merge variant (e.g., Hypotheses::AddCTC
or the Add overload that enables CTC merging) so identical token sequences (same
hyp.ys) are merged using CTC rules: combine log_prob_b and log_prob_nb via
log-sum-exp, preserve/update num_trailing_blanks correctly, and carry the proper
lm/context_state; update every place calling next_hyps.Add(std::move(new_hyp))
in this function (blank case, same-token case, and update_prefix branch) to use
the CTC-aware add so prefix merging is correct.
        new_hyp.ys.push_back(new_token);
        new_hyp.log_prob_nb = hyp.log_prob_b + log_prob;
        new_hyp.log_prob_b = -std::numeric_limits<float>::infinity();
        new_hyp.num_trailing_blanks = 0;
        update_prefix = true;
      } else {
        // Case 3: *a + b => *ab, *aε + b => *ab
        // Prefix changes, update log_prob of non_blank
        new_hyp.ys.push_back(new_token);
        new_hyp.log_prob_nb = hyp.LogProb(true) + log_prob;
Record and return token timestamps.
The prefix-expansion paths append to ys, but never append the corresponding frame index, and Decode() never writes best_hyp.timestamps back into r.timestamps. That leaves online prefix-beam results without timing metadata and breaks the tokens.size() == timestamps.size() contract.
Suggested direction
-static std::vector<Hypothesis> StepWorker(const float *p_log_probs,
-                                          std::vector<Hypothesis> &hyps,
-                                          int32_t blank_id, int32_t vocab_size,
-                                          int32_t max_active_paths,
-                                          const ContextGraph *context_graph) {
+static std::vector<Hypothesis> StepWorker(const float *p_log_probs,
+                                          std::vector<Hypothesis> &hyps,
+                                          int32_t blank_id, int32_t vocab_size,
+                                          int32_t max_active_paths,
+                                          int32_t frame_index,
+                                          const ContextGraph *context_graph) {
 ...
         new_hyp.ys.push_back(new_token);
+        new_hyp.timestamps.push_back(frame_index);
         new_hyp.log_prob_nb = hyp.log_prob_b + log_prob;
 ...
         new_hyp.ys.push_back(new_token);
+        new_hyp.timestamps.push_back(frame_index);
         new_hyp.log_prob_nb = hyp.LogProb(true) + log_prob;
 ...
-      cur[b] = StepWorker(p, cur[b], blank_id_, vocab_size,
-                          max_active_paths_, context_graphs[b].get());
+      cur[b] = StepWorker(p, cur[b], blank_id_, vocab_size,
+                          max_active_paths_,
+                          (*results)[b].frame_offset + t,
+                          context_graphs[b].get());
 ...
       r.tokens = best_hyp.ys;
+      r.timestamps = best_hyp.timestamps;
       r.frame_offset += num_frames;

Also applies to: 58-63, 138-140
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@sherpa-onnx/csrc/online-ctc-prefix-beam-search-decoder.cc` around lines 50 -
59, When expanding prefixes where code pushes new_token into new_hyp.ys (e.g.,
in the branches that set new_hyp.log_prob_nb and new_hyp.log_prob_b), also push
the corresponding frame index into new_hyp.timestamps (use the current
frame/time variable used in Decode, e.g., t or frame_idx); likewise ensure other
similar blocks (the ones around the other push_back sites and the ones noted at
58-63 and 138-140) update timestamps in lockstep with ys. Finally, when
finishing Decode(), copy best_hyp.timestamps into r.timestamps so the returned
Result preserves the token-to-timestamp correspondence (ensuring tokens.size()
== timestamps.size()).
   if (!hotwords_file.empty() && decoding_method != "modified_beam_search" &&
       decoding_method != "prefix_beam_search") {
     SHERPA_ONNX_LOGE(
-        "Please use --decoding-method=modified_beam_search if you"
+        "Please use --decoding-method=modified_beam_search or "
+        "--decoding-method=prefix_beam_search if you"
         " provide --hotwords-file. Given --decoding-method=%s",
         decoding_method.c_str());
     return false;
Reject hotwords_file when a CTC graph is configured.
This new validation now accepts hotwords_file with prefix_beam_search, but in sherpa-onnx/csrc/online-recognizer-ctc-impl.h the ctc_fst_decoder_config.graph path is chosen before prefix_beam_search, so the hotword biasing is silently ignored. Please fail fast for that combination instead of accepting a no-op config.
Suggested validation

+  if (!hotwords_file.empty() && !ctc_fst_decoder_config.graph.empty()) {
+    SHERPA_ONNX_LOGE(
+        "--hotwords-file is not supported when a CTC graph is configured.");
+    return false;
+  }
+
   if (!hotwords_file.empty() && decoding_method != "modified_beam_search" &&
       decoding_method != "prefix_beam_search") {

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
  if (!hotwords_file.empty() && !ctc_fst_decoder_config.graph.empty()) {
    SHERPA_ONNX_LOGE(
        "--hotwords-file is not supported when a CTC graph is configured.");
    return false;
  }

  if (!hotwords_file.empty() && decoding_method != "modified_beam_search" &&
      decoding_method != "prefix_beam_search") {
    SHERPA_ONNX_LOGE(
        "Please use --decoding-method=modified_beam_search or "
        "--decoding-method=prefix_beam_search if you"
        " provide --hotwords-file. Given --decoding-method=%s",
        decoding_method.c_str());
    return false;
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@sherpa-onnx/csrc/online-recognizer.cc` around lines 156 - 163, The current
validation accepts --hotwords-file with --decoding-method=prefix_beam_search
even when a CTC graph is configured, causing hotword biasing to be ignored;
update the check that currently inspects hotwords_file and decoding_method to
also reject when ctc_fst_decoder_config.graph is set. Specifically, in the
validation branch that looks at hotwords_file and decoding_method (symbols:
hotwords_file, decoding_method, "modified_beam_search", "prefix_beam_search"),
add a condition to fail (log error and return false) if
ctc_fst_decoder_config.graph is non-empty (symbol: ctc_fst_decoder_config.graph)
so that supplying hotwords with a configured CTC FST graph is rejected up front.
Ensure the error message mentions both hotwords_file and the presence of a CTC
graph to make the conflict clear.
This PR implements the core part (C++/Python) of the CTC prefix beam search decoding methods, including hotwords and RNN-LM shallow fusion.
By the way, we have released our recent progress on CTC models; see https://arxiv.org/pdf/2410.05101 for details.
Summary by CodeRabbit
New Features
Documentation