Skip to content

Commit 8b75d4c

Browse files
committed
more refine
1 parent 8d477bd commit 8b75d4c

4 files changed

Lines changed: 36 additions & 17 deletions

File tree

c-api-examples/decode-file-c-api.c

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,8 @@ static struct cag_option options[] = {
6565
.access_name = "hotwords-score",
6666
.value_name = "hotwords-score",
6767
.description = "The bonus score for each token in hotwords. Used only "
68-
"when decoding_method is modified_beam_search"},
68+
"when decoding_method is modified_beam_search or "
69+
"prefix_beam_search"},
6970
};
7071

7172
const char *kUsage =
@@ -80,7 +81,8 @@ const char *kUsage =
8081
" /path/to/foo.wav\n"
8182
"\n\n"
8283
"Default num_threads is 1.\n"
83-
"Valid decoding_method: greedy_search (default), modified_beam_search\n\n"
84+
"Valid decoding_method: greedy_search (default), modified_beam_search,\n"
85+
" prefix_beam_search (for CTC models)\n\n"
8486
"Valid provider: cpu (default), cuda, coreml\n\n"
8587
"Please refer to \n"
8688
"https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/"

python-api-examples/online-decode-files.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,10 @@ def get_args():
172172
"--decoding-method",
173173
type=str,
174174
default="greedy_search",
175-
help="Valid values are greedy_search and modified_beam_search",
175+
help=(
176+
"Valid values are greedy_search, modified_beam_search "
177+
"(for transducer), and prefix_beam_search (for CTC)"
178+
),
176179
)
177180

178181
parser.add_argument(
@@ -368,7 +371,12 @@ def main():
368371
provider=args.provider,
369372
sample_rate=16000,
370373
feature_dim=80,
371-
decoding_method="greedy_search",
374+
decoding_method=args.decoding_method,
375+
max_active_paths=args.max_active_paths,
376+
hotwords_file=args.hotwords_file,
377+
hotwords_score=args.hotwords_score,
378+
modeling_unit=args.modeling_unit,
379+
bpe_vocab=args.bpe_vocab,
372380
)
373381
elif args.paraformer_encoder:
374382
recognizer = sherpa_onnx.OnlineRecognizer.from_paraformer(
@@ -391,7 +399,12 @@ def main():
391399
provider=args.provider,
392400
sample_rate=16000,
393401
feature_dim=80,
394-
decoding_method="greedy_search",
402+
decoding_method=args.decoding_method,
403+
max_active_paths=args.max_active_paths,
404+
hotwords_file=args.hotwords_file,
405+
hotwords_score=args.hotwords_score,
406+
modeling_unit=args.modeling_unit,
407+
bpe_vocab=args.bpe_vocab,
395408
)
396409
else:
397410
raise ValueError("Please provide a model")

sherpa-onnx/c-api/c-api.h

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -340,10 +340,11 @@ typedef struct SherpaOnnxOnlineRecognizerConfig {
340340
/** Streaming model configuration. */
341341
SherpaOnnxOnlineModelConfig model_config;
342342

343-
/** Decoding method, for example "greedy_search" or "modified_beam_search". */
343+
/** Decoding method: "greedy_search", "modified_beam_search" (transducer),
344+
* or "prefix_beam_search" (CTC). */
344345
const char *decoding_method;
345346

346-
/** Number of active paths for modified beam search. */
347+
/** Number of active paths for modified_beam_search / prefix_beam_search. */
347348
int32_t max_active_paths;
348349

349350
/** Set to non-zero to enable endpoint detection. */
@@ -358,7 +359,8 @@ typedef struct SherpaOnnxOnlineRecognizerConfig {
358359
/** Endpoint rule 3 utterance-length threshold in seconds. */
359360
float rule3_min_utterance_length;
360361

361-
/** Path to a hotwords file. */
362+
/** Path to a hotwords file. Used with modified_beam_search or
363+
* prefix_beam_search. */
362364
const char *hotwords_file;
363365

364366
/** Bonus score added to each hotword token during decoding. */
@@ -1160,12 +1162,14 @@ typedef struct SherpaOnnxOfflineRecognizerConfig {
11601162
/** Optional language model configuration. */
11611163
SherpaOnnxOfflineLMConfig lm_config;
11621164

1163-
/** Decoding method, for example "greedy_search" or "modified_beam_search". */
1165+
/** Decoding method: "greedy_search", "modified_beam_search" (transducer),
1166+
* or "prefix_beam_search" (CTC). */
11641167
const char *decoding_method;
1165-
/** Number of active paths for modified beam search. */
1168+
/** Number of active paths for modified_beam_search / prefix_beam_search. */
11661169
int32_t max_active_paths;
11671170

1168-
/** Path to a hotwords file. */
1171+
/** Path to a hotwords file. Used with modified_beam_search or
1172+
* prefix_beam_search. */
11691173
const char *hotwords_file;
11701174

11711175
/** Bonus score added to each hotword token. */

sherpa-onnx/csrc/online-ctc-prefix-beam-search-decoder.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
// sherpa-onnx/csrc/online-ctc-prefix-beam-search-decoder.cc
22
//
3-
// Copyright (c) 2024 Xiaomi Corporation
3+
// Copyright (c) 2026 Xiaomi Corporation
44

55
#include "sherpa-onnx/csrc/online-ctc-prefix-beam-search-decoder.h"
66

@@ -34,12 +34,14 @@ static std::vector<Hypothesis> StepWorker(const float *p_log_probs,
3434
// Prefix does not change, update log_prob of blank
3535
new_hyp.log_prob_nb = -std::numeric_limits<float>::infinity();
3636
new_hyp.log_prob_b = hyp.LogProb(true) + log_prob;
37+
new_hyp.num_trailing_blanks = hyp.num_trailing_blanks + 1;
3738
next_hyps.Add(std::move(new_hyp));
3839
} else if (hyp.ys.size() > 0 && hyp.ys.back() == new_token) {
3940
// Case 1: *a + a => *a
4041
// Prefix does not change, update log_prob of non_blank
4142
new_hyp.log_prob_nb = hyp.log_prob_nb + log_prob;
4243
new_hyp.log_prob_b = -std::numeric_limits<float>::infinity();
44+
new_hyp.num_trailing_blanks = 0;
4345
next_hyps.Add(std::move(new_hyp));
4446

4547
// Case 2: *aε + a => *aa
@@ -48,13 +50,15 @@ static std::vector<Hypothesis> StepWorker(const float *p_log_probs,
4850
new_hyp.ys.push_back(new_token);
4951
new_hyp.log_prob_nb = hyp.log_prob_b + log_prob;
5052
new_hyp.log_prob_b = -std::numeric_limits<float>::infinity();
53+
new_hyp.num_trailing_blanks = 0;
5154
update_prefix = true;
5255
} else {
5356
// Case 3: *a + b => *ab, *aε + b => *ab
5457
// Prefix changes, update log_prob of non_blank
5558
new_hyp.ys.push_back(new_token);
5659
new_hyp.log_prob_nb = hyp.LogProb(true) + log_prob;
5760
new_hyp.log_prob_b = -std::numeric_limits<float>::infinity();
61+
new_hyp.num_trailing_blanks = 0;
5862
update_prefix = true;
5963
}
6064

@@ -133,11 +137,7 @@ void OnlineCtcPrefixBeamSearchDecoder::Decode(
133137

134138
r.tokens = best_hyp.ys;
135139
r.frame_offset += num_frames;
136-
137-
// Count trailing blanks for endpointing
138-
if (!best_hyp.ys.empty()) {
139-
r.num_trailing_blanks = 0;
140-
}
140+
r.num_trailing_blanks = best_hyp.num_trailing_blanks;
141141
}
142142
}
143143

0 commit comments

Comments
 (0)