Skip to content

Commit 3408053

Browse files
authored
Merge pull request #2 from pkufool/copilot/implement-ctc-prefix-beam-search-python
Implementing offline version of CTC prefix beam search in Python
2 parents a2b64a2 + 0eff9c7 commit 3408053

2 files changed

Lines changed: 58 additions & 10 deletions

File tree

python-api-examples/offline-decode-files.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -286,7 +286,10 @@ def get_args():
286286
"--decoding-method",
287287
type=str,
288288
default="greedy_search",
289-
help="Valid values are greedy_search and modified_beam_search",
289+
help=(
290+
"Valid values are greedy_search, modified_beam_search "
291+
"(for transducer), and prefix_beam_search (for CTC)"
292+
),
290293
)
291294

292295
parser.add_argument(

sherpa-onnx/python/sherpa_onnx/offline_recognizer.py

Lines changed: 54 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -588,6 +588,7 @@ def from_telespeech_ctc(
588588
sample_rate: int = 16000,
589589
feature_dim: int = 40,
590590
decoding_method: str = "greedy_search",
591+
max_active_paths: int = 4,
591592
debug: bool = False,
592593
provider: str = "cpu",
593594
rule_fsts: str = "",
@@ -619,7 +620,10 @@ def from_telespeech_ctc(
619620
Dimension of the feature used to train the model. It is ignored
620621
and is hard-coded in C++ to 40.
621622
decoding_method:
622-
Valid values are greedy_search.
623+
Valid values are greedy_search and prefix_beam_search.
624+
max_active_paths:
625+
Maximum number of active paths to keep. Used only when
626+
decoding_method is prefix_beam_search.
623627
debug:
624628
True to show debug messages.
625629
provider:
@@ -649,6 +653,7 @@ def from_telespeech_ctc(
649653
feat_config=feat_config,
650654
model_config=model_config,
651655
decoding_method=decoding_method,
656+
max_active_paths=max_active_paths,
652657
rule_fsts=rule_fsts,
653658
rule_fars=rule_fars,
654659
hr=HomophoneReplacerConfig(
@@ -668,6 +673,7 @@ def from_dolphin_ctc(
668673
sample_rate: int = 16000,
669674
feature_dim: int = 80,
670675
decoding_method: str = "greedy_search",
676+
max_active_paths: int = 4,
671677
debug: bool = False,
672678
provider: str = "cpu",
673679
rule_fsts: str = "",
@@ -697,7 +703,10 @@ def from_dolphin_ctc(
697703
feature_dim:
698704
Dimension of the feature used to train the model.
699705
decoding_method:
700-
Valid values are greedy_search.
706+
Valid values are greedy_search and prefix_beam_search.
707+
max_active_paths:
708+
Maximum number of active paths to keep. Used only when
709+
decoding_method is prefix_beam_search.
701710
debug:
702711
True to show debug messages.
703712
provider:
@@ -727,6 +736,7 @@ def from_dolphin_ctc(
727736
feat_config=feat_config,
728737
model_config=model_config,
729738
decoding_method=decoding_method,
739+
max_active_paths=max_active_paths,
730740
rule_fsts=rule_fsts,
731741
rule_fars=rule_fars,
732742
hr=HomophoneReplacerConfig(
@@ -746,6 +756,7 @@ def from_fire_red_asr_ctc(
746756
tokens: str,
747757
num_threads: int = 1,
748758
decoding_method: str = "greedy_search",
759+
max_active_paths: int = 4,
749760
debug: bool = False,
750761
provider: str = "cpu",
751762
):
@@ -766,7 +777,10 @@ def from_fire_red_asr_ctc(
766777
num_threads:
767778
Number of threads for neural network computation.
768779
decoding_method:
769-
The only supported decoding method is greedy_search.
780+
Valid values are greedy_search and prefix_beam_search.
781+
max_active_paths:
782+
Maximum number of active paths to keep. Used only when
783+
decoding_method is prefix_beam_search.
770784
debug:
771785
True to show debug messages.
772786
provider:
@@ -784,6 +798,7 @@ def from_fire_red_asr_ctc(
784798
recognizer_config = OfflineRecognizerConfig(
785799
model_config=model_config,
786800
decoding_method=decoding_method,
801+
max_active_paths=max_active_paths,
787802
)
788803
self.recognizer = _Recognizer(recognizer_config)
789804
self.config = recognizer_config
@@ -796,6 +811,7 @@ def from_medasr_ctc(
796811
tokens: str,
797812
num_threads: int = 1,
798813
decoding_method: str = "greedy_search",
814+
max_active_paths: int = 4,
799815
debug: bool = False,
800816
provider: str = "cpu",
801817
):
@@ -816,7 +832,10 @@ def from_medasr_ctc(
816832
num_threads:
817833
Number of threads for neural network computation.
818834
decoding_method:
819-
The only supported decoding method is greedy_search.
835+
Valid values are greedy_search and prefix_beam_search.
836+
max_active_paths:
837+
Maximum number of active paths to keep. Used only when
838+
decoding_method is prefix_beam_search.
820839
debug:
821840
True to show debug messages.
822841
provider:
@@ -834,6 +853,7 @@ def from_medasr_ctc(
834853
recognizer_config = OfflineRecognizerConfig(
835854
model_config=model_config,
836855
decoding_method=decoding_method,
856+
max_active_paths=max_active_paths,
837857
)
838858
self.recognizer = _Recognizer(recognizer_config)
839859
self.config = recognizer_config
@@ -846,6 +866,7 @@ def from_omnilingual_asr_ctc(
846866
tokens: str,
847867
num_threads: int = 1,
848868
decoding_method: str = "greedy_search",
869+
max_active_paths: int = 4,
849870
debug: bool = False,
850871
provider: str = "cpu",
851872
):
@@ -866,7 +887,10 @@ def from_omnilingual_asr_ctc(
866887
num_threads:
867888
Number of threads for neural network computation.
868889
decoding_method:
869-
The only supported decoding method is greedy_search.
890+
Valid values are greedy_search and prefix_beam_search.
891+
max_active_paths:
892+
Maximum number of active paths to keep. Used only when
893+
decoding_method is prefix_beam_search.
870894
debug:
871895
True to show debug messages.
872896
provider:
@@ -884,6 +908,7 @@ def from_omnilingual_asr_ctc(
884908
recognizer_config = OfflineRecognizerConfig(
885909
model_config=model_config,
886910
decoding_method=decoding_method,
911+
max_active_paths=max_active_paths,
887912
)
888913
self.recognizer = _Recognizer(recognizer_config)
889914
self.config = recognizer_config
@@ -898,6 +923,7 @@ def from_zipformer_ctc(
898923
sample_rate: int = 16000,
899924
feature_dim: int = 80,
900925
decoding_method: str = "greedy_search",
926+
max_active_paths: int = 4,
901927
debug: bool = False,
902928
provider: str = "cpu",
903929
rule_fsts: str = "",
@@ -928,7 +954,10 @@ def from_zipformer_ctc(
928954
feature_dim:
929955
Dimension of the feature used to train the model.
930956
decoding_method:
931-
Valid values are greedy_search.
957+
Valid values are greedy_search and prefix_beam_search.
958+
max_active_paths:
959+
Maximum number of active paths to keep. Used only when
960+
decoding_method is prefix_beam_search.
932961
debug:
933962
True to show debug messages.
934963
provider:
@@ -958,6 +987,7 @@ def from_zipformer_ctc(
958987
feat_config=feat_config,
959988
model_config=model_config,
960989
decoding_method=decoding_method,
990+
max_active_paths=max_active_paths,
961991
rule_fsts=rule_fsts,
962992
rule_fars=rule_fars,
963993
hr=HomophoneReplacerConfig(
@@ -979,6 +1009,7 @@ def from_nemo_ctc(
9791009
sample_rate: int = 16000,
9801010
feature_dim: int = 80,
9811011
decoding_method: str = "greedy_search",
1012+
max_active_paths: int = 4,
9821013
debug: bool = False,
9831014
provider: str = "cpu",
9841015
rule_fsts: str = "",
@@ -1009,7 +1040,10 @@ def from_nemo_ctc(
10091040
feature_dim:
10101041
Dimension of the feature used to train the model.
10111042
decoding_method:
1012-
Valid values are greedy_search.
1043+
Valid values are greedy_search and prefix_beam_search.
1044+
max_active_paths:
1045+
Maximum number of active paths to keep. Used only when
1046+
decoding_method is prefix_beam_search.
10131047
debug:
10141048
True to show debug messages.
10151049
provider:
@@ -1040,6 +1074,7 @@ def from_nemo_ctc(
10401074
feat_config=feat_config,
10411075
model_config=model_config,
10421076
decoding_method=decoding_method,
1077+
max_active_paths=max_active_paths,
10431078
rule_fsts=rule_fsts,
10441079
rule_fars=rule_fars,
10451080
hr=HomophoneReplacerConfig(
@@ -1617,6 +1652,7 @@ def from_tdnn_ctc(
16171652
sample_rate: int = 8000,
16181653
feature_dim: int = 23,
16191654
decoding_method: str = "greedy_search",
1655+
max_active_paths: int = 4,
16201656
debug: bool = False,
16211657
provider: str = "cpu",
16221658
rule_fsts: str = "",
@@ -1646,7 +1682,10 @@ def from_tdnn_ctc(
16461682
feature_dim:
16471683
Dimension of the feature used to train the model.
16481684
decoding_method:
1649-
Valid values are greedy_search.
1685+
Valid values are greedy_search and prefix_beam_search.
1686+
max_active_paths:
1687+
Maximum number of active paths to keep. Used only when
1688+
decoding_method is prefix_beam_search.
16501689
debug:
16511690
True to show debug messages.
16521691
provider:
@@ -1677,6 +1716,7 @@ def from_tdnn_ctc(
16771716
feat_config=feat_config,
16781717
model_config=model_config,
16791718
decoding_method=decoding_method,
1719+
max_active_paths=max_active_paths,
16801720
rule_fsts=rule_fsts,
16811721
rule_fars=rule_fars,
16821722
hr=HomophoneReplacerConfig(
@@ -1698,6 +1738,7 @@ def from_wenet_ctc(
16981738
sample_rate: int = 16000,
16991739
feature_dim: int = 80,
17001740
decoding_method: str = "greedy_search",
1741+
max_active_paths: int = 4,
17011742
debug: bool = False,
17021743
provider: str = "cpu",
17031744
rule_fsts: str = "",
@@ -1728,7 +1769,10 @@ def from_wenet_ctc(
17281769
feature_dim:
17291770
Dimension of the feature used to train the model.
17301771
decoding_method:
1731-
Valid values are greedy_search.
1772+
Valid values are greedy_search and prefix_beam_search.
1773+
max_active_paths:
1774+
Maximum number of active paths to keep. Used only when
1775+
decoding_method is prefix_beam_search.
17321776
debug:
17331777
True to show debug messages.
17341778
provider:
@@ -1759,6 +1803,7 @@ def from_wenet_ctc(
17591803
feat_config=feat_config,
17601804
model_config=model_config,
17611805
decoding_method=decoding_method,
1806+
max_active_paths=max_active_paths,
17621807
rule_fsts=rule_fsts,
17631808
rule_fars=rule_fars,
17641809
hr=HomophoneReplacerConfig(

0 commit comments

Comments
 (0)