Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions examples/asr/conf/asr_streaming_inference/buffered_rnnt.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,23 @@ asr:
device_id: 0 # GPU device ID
compute_dtype: bfloat16 # Compute precision: 'bfloat16' for Ampere+, 'float16' for older GPUs, or 'float32'
use_amp: false # Enable Automatic Mixed Precision
ngram_lm_model: "" # Path to ngram language model
ngram_lm_alpha: 0.0 # Alpha for language model
decoding:
strategy: "greedy_batch"
preserve_alignments: false
fused_batch_size: -1
greedy:
use_cuda_graph_decoder: true
max_symbols: 10
# n-gram LM
ngram_lm_model: null # The path to the built '.nemo' NGPU-LM model
ngram_lm_alpha: 0.0 # Weight of the LM model
# phrase boosting
boosting_tree:
model_path: null # The path to the built '.nemo' boosting tree model
key_phrases_file: null # The path to the context-biasing list file (one phrase per line)
key_phrases_list: null # The list of context-biasing phrases ['word1', 'word2', 'word3', ...]
source_lang: "en" # The source language of the context-biasing phrases (for aggregate tokenizer)
boosting_tree_alpha: 0.0 # Weight of the boosting tree


# ==========================================
Expand Down
18 changes: 17 additions & 1 deletion examples/asr/conf/asr_streaming_inference/cache_aware_rnnt.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,23 @@ asr:
device_id: 0 # GPU device ID
compute_dtype: bfloat16 # Compute precision: 'bfloat16' for Ampere+, 'float16' for older GPUs, or 'float32'
use_amp: true # Enable Automatic Mixed Precision

decoding:
strategy: "greedy_batch"
preserve_alignments: false
fused_batch_size: -1
greedy:
use_cuda_graph_decoder: false # Disabled due to issues with decoding
max_symbols: 10
# n-gram LM
ngram_lm_model: null # The path to the built '.nemo' NGPU-LM model
ngram_lm_alpha: 0.0 # Weight of the LM model
# phrase boosting
boosting_tree:
model_path: null # The path to the built '.nemo' boosting tree model
key_phrases_file: null # The path to the context-biasing list file (one phrase per line)
key_phrases_list: null # The list of context-biasing phrases ['word1', 'word2', 'word3', ...]
source_lang: "en" # The source language of the context-biasing phrases (for aggregate tokenizer)
boosting_tree_alpha: 0.0 # Weight of the boosting tree

# ==========================================
# Inverse Text Normalization Configuration
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from omegaconf.dictconfig import DictConfig
from omegaconf import DictConfig, OmegaConf

from nemo.collections.asr.inference.factory.base_builder import BaseBuilder
from nemo.collections.asr.inference.pipelines.buffered_ctc_pipeline import BufferedCTCPipeline
Expand Down Expand Up @@ -54,25 +54,9 @@ def get_rnnt_decoding_cfg(cls, cfg: DictConfig) -> RNNTDecodingConfig:
Returns:
(RNNTDecodingConfig) Decoding config
"""
decoding_cfg = RNNTDecodingConfig()

# greedy_batch decoding strategy required for stateless streaming
decoding_cfg.strategy = "greedy_batch"

# required to compute the middle token for transducers.
decoding_cfg.preserve_alignments = False

# temporarily stop fused batch during inference.
decoding_cfg.fused_batch_size = -1

# return and write the best hypothesis only
decoding_cfg.beam.return_best_hypothesis = True

# setup ngram language model
if hasattr(cfg.asr, "ngram_lm_model") and cfg.asr.ngram_lm_model != "":
decoding_cfg.greedy.ngram_lm_model = cfg.asr.ngram_lm_model
decoding_cfg.greedy.ngram_lm_alpha = cfg.asr.ngram_lm_alpha

base_cfg_structured = OmegaConf.structured(RNNTDecodingConfig)
base_cfg = OmegaConf.create(OmegaConf.to_container(base_cfg_structured))
decoding_cfg = OmegaConf.merge(base_cfg, cfg.asr.decoding)
return decoding_cfg

@classmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from omegaconf.dictconfig import DictConfig
from omegaconf import DictConfig, OmegaConf

from nemo.collections.asr.inference.factory.base_builder import BaseBuilder
from nemo.collections.asr.inference.pipelines.cache_aware_ctc_pipeline import CacheAwareCTCPipeline
Expand Down Expand Up @@ -48,18 +48,15 @@ def build(cls, cfg: DictConfig) -> CacheAwareCTCPipeline | CacheAwareRNNTPipelin
raise ValueError("Invalid asr decoding type for cache aware streaming. Need to be one of ['CTC', 'RNNT']")

@classmethod
def get_rnnt_decoding_cfg(cls, cfg: DictConfig) -> RNNTDecodingConfig:
    """
    Get the decoding config for the RNNT pipeline.

    Args:
        cfg: Top-level inference config; ``cfg.asr.decoding`` carries the
            user-provided overrides (strategy, greedy options, n-gram LM,
            boosting tree, ...) that are layered on top of the structured
            ``RNNTDecodingConfig`` defaults.

    Returns:
        (RNNTDecodingConfig) Decoding config
    """
    # Build the full set of defaults from the structured dataclass, then
    # convert to a plain container so the merge below is not constrained
    # by struct-mode field checking.
    base_cfg_structured = OmegaConf.structured(RNNTDecodingConfig)
    base_cfg = OmegaConf.create(OmegaConf.to_container(base_cfg_structured))
    # User overrides from the YAML config win over the defaults.
    decoding_cfg = OmegaConf.merge(base_cfg, cfg.asr.decoding)
    return decoding_cfg

@classmethod
Expand All @@ -84,7 +81,7 @@ def build_cache_aware_rnnt_pipeline(cls, cfg: DictConfig) -> CacheAwareRNNTPipel
Returns CacheAwareRNNTPipeline object
"""
# building ASR model
decoding_cfg = cls.get_rnnt_decoding_cfg()
decoding_cfg = cls.get_rnnt_decoding_cfg(cfg)
asr_model = cls._build_asr(cfg, decoding_cfg)

# building ITN model
Expand Down
Loading