Merged
82 commits
b4c33dd
Fused biasing ids in interface
artbataev Nov 17, 2025
0e8f25d
Bugfix
artbataev Nov 17, 2025
066d5d3
Fused model: non-batched reference implementation
artbataev Nov 17, 2025
9d65743
Change api
artbataev Nov 19, 2025
28aca3d
Temporary bugfix
artbataev Nov 20, 2025
fe42d8a
Allow partial hypothesis without dec_state
artbataev Nov 20, 2025
ca2cf4b
Multi-model biasing: implement with RNN-T
artbataev Nov 21, 2025
a290f32
Multi-model biasing: implement for TDT, fix RNN-T
artbataev Nov 21, 2025
e288228
Add and remove boosting model from multi-model: move logic to Hypothesis
artbataev Nov 22, 2025
8d36596
Clean up per-hyp biasing config
artbataev Nov 26, 2025
b8df95e
Auto manage biasing requests in transcribe
artbataev Nov 26, 2025
e74f7b1
Disable CUDA graphs with per-stream biasing
artbataev Nov 26, 2025
558dba2
Support specifying lang per phrase
artbataev Nov 26, 2025
972fc29
Fix customization options
artbataev Nov 26, 2025
352027f
Fix streaming decoding
artbataev Nov 26, 2025
e4e7fc3
Fix streaming decoding
artbataev Nov 26, 2025
ac88aa3
Fix type
artbataev Nov 26, 2025
768e51a
Implement custom requests in nemo inference
artbataev Nov 26, 2025
e24fc99
Merge branch 'main' into vbataev/multi_biasing_models
artbataev Nov 26, 2025
7209d63
Use biasing options from state instead of request
artbataev Nov 26, 2025
736f942
Fix f-strings
artbataev Nov 26, 2025
e5f9c71
Merge branch 'main' into vbataev/multi_biasing_models
artbataev Nov 28, 2025
d8649b7
Expose parameter: enable_per_stream_biasing
artbataev Nov 28, 2025
285ac29
key_phrase_items_list: use parameter in config
artbataev Nov 28, 2025
59139ef
More documentation
artbataev Nov 28, 2025
3680f07
Clean up biasing_cfg management
artbataev Nov 28, 2025
8aea8ad
Specify todo
artbataev Nov 28, 2025
2b20d8e
Clean up
artbataev Nov 28, 2025
4c55a79
Stubs for optimized model
artbataev Nov 28, 2025
5875202
Multi-model: implement adding model
artbataev Dec 1, 2025
1c04627
Multi-model: stubs for advance
artbataev Dec 1, 2025
6840f0f
Multi-model: stubs for advance using Triton
artbataev Dec 1, 2025
494c4c1
Multi-model: implement Triton kernel
artbataev Dec 1, 2025
db84e03
Multi-model: implement advance in Pytorch
artbataev Dec 1, 2025
225a644
Multi-model: use optimized implementation
artbataev Dec 1, 2025
586221e
Multi-model: implement model removal
artbataev Dec 1, 2025
398e4ed
Clean up
artbataev Dec 1, 2025
b9a4ba2
Clean up
artbataev Dec 1, 2025
cce5b26
Fix triton implementation
artbataev Dec 1, 2025
13af31b
Fix flake8 suggestions
artbataev Dec 1, 2025
985ccb4
Merge branch 'main' into vbataev/multi_biasing_models
artbataev Dec 3, 2025
5cfc624
Support CUDA graphs for RNN-T
artbataev Dec 3, 2025
7a00812
Support CUDA graphs for TDT
artbataev Dec 3, 2025
0d44900
Fix TDT
artbataev Dec 3, 2025
dfdb876
Merge branch 'main' into vbataev/multi_biasing_models
artbataev Dec 8, 2025
bfcf441
Fix vocab size
artbataev Dec 9, 2025
fb20279
Fix edge case
artbataev Dec 9, 2025
dac3ef8
Add test
artbataev Dec 9, 2025
4aa7b9d
Remove OneLogger causing issues
artbataev Dec 9, 2025
71638dc
Clean up logging
artbataev Dec 9, 2025
90e2ed3
Fix model removal
artbataev Dec 9, 2025
3c9b2db
Efficient addition and removal of models
artbataev Dec 10, 2025
103c4af
Fix Triton implementation
artbataev Dec 10, 2025
bc80bf5
Fix CUDA graphs
artbataev Dec 10, 2025
01b361c
Clean up
artbataev Dec 10, 2025
0c62cb5
Bugfix
artbataev Dec 10, 2025
f47ebd5
Bugfix
artbataev Dec 10, 2025
bec685c
Clean up
artbataev Dec 10, 2025
b7e78ef
Merge branch 'main' into vbataev/multi_biasing_models
artbataev Dec 10, 2025
28bc0d3
Add docstrings
artbataev Dec 22, 2025
ca58d79
Fix moving fusion models to the device
artbataev Dec 22, 2025
39d73f6
Fix inference mode issues
artbataev Dec 22, 2025
4cae6dd
Merge branch 'main' into vbataev/multi_biasing_models
artbataev Dec 22, 2025
526e0c8
Revert back One Logger callback
artbataev Dec 22, 2025
bf06b3b
Implement memory cache
artbataev Dec 22, 2025
9acac89
Fix model removal
artbataev Dec 22, 2025
a36f3fa
Clean up
artbataev Dec 22, 2025
adf29aa
Add nemo inference test with boosting ground truth
artbataev Dec 22, 2025
5e0bd80
Add test with `asr_model.transcribe`
artbataev Dec 23, 2025
a4e1582
Clean up
artbataev Dec 23, 2025
cc0c3a1
Clean up
artbataev Dec 23, 2025
fad3cad
Clean up decoders
artbataev Dec 23, 2025
8e9a986
Merge branch 'main' into vbataev/multi_biasing_models
artbataev Dec 23, 2025
1908759
Add unit tests
artbataev Dec 23, 2025
d49b342
Fix copyright
artbataev Dec 23, 2025
5589360
Fix tests
artbataev Dec 25, 2025
60f53aa
Support biasing request in manifest for streaming inference scripts
artbataev Jan 12, 2026
6c3981c
Merge branch 'main' into vbataev/multi_biasing_models
artbataev Jan 12, 2026
e33942a
Use `asr_streaming_infer.py` in functional tests. Add test with per-s…
artbataev Jan 12, 2026
e8d0524
Remove redundant WER calculation
artbataev Jan 13, 2026
5470fc7
Clean up. Add docstring
artbataev Jan 13, 2026
9b479f8
Remove unused import
artbataev Jan 13, 2026
4 changes: 4 additions & 0 deletions .github/workflows/cicd-main-speech.yml
@@ -131,6 +131,10 @@ jobs:
script: L2_Speech_Transcription_Speech_to_Text_Cache_Aware_Infer
- runner: self-hosted-azure
script: L2_Speech_Transcription_Streaming_Inference
- runner: self-hosted-azure
script: L2_Speech_Transcription_Speech_to_Text_Inference_Boost_GT
- runner: self-hosted-azure
script: L2_Speech_Transcription_Speech_to_Text_Transcribe_Boost_GT
- runner: self-hosted-azure
script: L2_Speech_Transcription_Canary_Transcribe_Full_Manifest
- runner: self-hosted-azure
@@ -65,13 +65,15 @@
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = alloc_conf


import librosa
import lightning.pytorch as pl
import torch
from omegaconf import OmegaConf, open_dict
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from nemo.collections.asr.models import EncDecHybridRNNTCTCModel, EncDecRNNTModel
from nemo.collections.asr.parts.context_biasing.biasing_multi_model import BiasingRequestItemConfig
from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig
from nemo.collections.asr.parts.submodules.transducer_decoding.label_looping_base import (
GreedyBatchedLabelLoopingComputerBase,
@@ -95,6 +97,7 @@
)
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.timers import SimpleTimer


def make_divisible_by(num, factor: int) -> int:
@@ -145,6 +148,8 @@ class TranscriptionConfig:

# Decoding strategy for RNNT models
decoding: RNNTDecodingConfig = field(default_factory=RNNTDecodingConfig)
# Per-utterance biasing with biasing config in the manifest
use_per_stream_biasing: bool = False

timestamps: bool = False # output timestamps

@@ -154,6 +159,8 @@
langid: str = "en" # specify this for convert_num_to_words step in groundtruth cleaning
use_cer: bool = False

calculate_rtfx: bool = False


@hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig)
def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
@@ -216,6 +223,8 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
asr_model = asr_model.to(asr_model.device)
asr_model.to(compute_dtype)

use_per_stream_biasing = cfg.use_per_stream_biasing

# Change Decoding Config
with open_dict(cfg.decoding):
if cfg.decoding.strategy != "greedy_batch" or cfg.decoding.greedy.loop_labels is not True:
@@ -226,6 +235,8 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
cfg.decoding.greedy.preserve_alignments = False
cfg.decoding.fused_batch_size = -1 # temporarily stop fused batch during inference.
cfg.decoding.beam.return_best_hypothesis = True # return and write the best hypothesis only
if use_per_stream_biasing:
cfg.decoding.greedy.enable_per_stream_biasing = use_per_stream_biasing

# Setup decoding strategy
if hasattr(asr_model, 'change_decoding_strategy'):
@@ -289,8 +300,27 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
latency_secs = (context_samples.chunk + context_samples.right) / audio_sample_rate
logging.info(f"Theoretical latency: {latency_secs:.2f} seconds")

biasing_requests: list[BiasingRequestItemConfig | None] | None
if use_per_stream_biasing:
biasing_requests = [
(
BiasingRequestItemConfig(
**OmegaConf.to_container(
OmegaConf.merge(OmegaConf.structured(BiasingRequestItemConfig), record["biasing_request"])
)
)
if "biasing_request" in record
else None
)
for record in records
]
else:
biasing_requests = None

audio_dataset = SimpleAudioDataset(
audio_filenames=[record["audio_filepath"] for record in records], sample_rate=audio_sample_rate
audio_filenames=[record["audio_filepath"] for record in records],
sample_rate=audio_sample_rate,
biasing_requests=biasing_requests,
)
audio_dataloader = DataLoader(
dataset=audio_dataset,
@@ -302,9 +332,11 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
in_order=True,
)

timer = SimpleTimer()
with torch.no_grad(), torch.inference_mode():
all_hyps = []
audio_data: AudioBatch
timer.start(device=map_location)
for audio_data in tqdm(audio_dataloader):
# get audio
# NB: preprocessor runs on torch.float32, no need to cast dtype here
@@ -313,8 +345,21 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
batch_size = audio_batch.shape[0]
device = audio_batch.device

# decode audio by chunks
# add biasing requests to the decoder
if use_per_stream_biasing:
multi_biasing_ids = torch.full([batch_size], fill_value=-1, dtype=torch.long, device=map_location)
if audio_data.biasing_requests is not None:
for batch_i, request in enumerate(audio_data.biasing_requests):
if request is not None:
biasing_model = request.get_model(tokenizer=asr_model.tokenizer)
if biasing_model is not None:
multi_model_id = decoding_computer.biasing_multi_model.add_model(biasing_model)
request.multi_model_id = multi_model_id
multi_biasing_ids[batch_i] = multi_model_id
else:
multi_biasing_ids = None

# decode audio by chunks
current_batched_hyps: BatchedHyps | None = None
state = None
left_sample = 0
@@ -368,6 +413,7 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
encoder_context_batch.chunk,
),
prev_batched_state=state,
multi_biasing_ids=multi_biasing_ids,
)
# merge hyps with previous hyps
if current_batched_hyps is None:
@@ -380,7 +426,14 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
left_sample = right_sample
right_sample = min(right_sample + context_samples.chunk, audio_batch.shape[1]) # add next chunk

# remove biasing requests from the decoder
if use_per_stream_biasing and audio_data.biasing_requests is not None:
for request in audio_data.biasing_requests:
if request is not None and request.multi_model_id is not None:
decoding_computer.biasing_multi_model.remove_model(request.multi_model_id)
request.multi_model_id = None
all_hyps.extend(batched_hyps_to_hypotheses(current_batched_hyps, None, batch_size=batch_size))
timer.stop(device=map_location)

# convert text
for i, hyp in enumerate(all_hyps):
@@ -399,6 +452,14 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
)
logging.info(f"Finished writing predictions to {output_filename}!")

if cfg.calculate_rtfx:
durations = [
record["duration"] if "duration" in record else librosa.get_duration(path=record["audio_filepath"])
for record in records
]
rtfx_measurements = sum(durations) / timer.total_sec()
logging.info(f"Model RTFx on the dataset: {rtfx_measurements:.3f}")

if cfg.calculate_wer:
output_manifest_w_wer, total_res, _ = cal_write_wer(
pred_manifest=output_filename,
37 changes: 30 additions & 7 deletions examples/asr/asr_streaming_inference/asr_streaming_infer.py
@@ -42,15 +42,17 @@
"""


from time import time

import hydra
from omegaconf import OmegaConf

from nemo.collections.asr.inference.factory.pipeline_builder import PipelineBuilder
from nemo.collections.asr.inference.streaming.framing.request_options import ASRRequestOptions
from nemo.collections.asr.inference.utils.manifest_io import calculate_duration, dump_output, get_audio_filepaths
from nemo.collections.asr.inference.utils.pipeline_eval import calculate_pipeline_laal, evaluate_pipeline
from nemo.collections.asr.inference.utils.progressbar import TQDMProgressBar
from nemo.collections.asr.parts.context_biasing.biasing_multi_model import BiasingRequestItemConfig
from nemo.utils import logging
from nemo.utils.timers import SimpleTimer

# disable nemo_text_processing logging
try:
@@ -80,15 +82,36 @@ def main(cfg):
pipeline = PipelineBuilder.build_pipeline(cfg)
progress_bar = TQDMProgressBar()

# Add biasing requests
if manifest:
options = [
ASRRequestOptions(
biasing_cfg=(
BiasingRequestItemConfig(
**OmegaConf.to_container(
OmegaConf.merge(OmegaConf.structured(BiasingRequestItemConfig), record["biasing_request"])
)
)
if "biasing_request" in record
else None
)
)
for record in manifest
]
else:
options = None

# Run the pipeline
start = time()
output = pipeline.run(audio_filepaths, progress_bar=progress_bar)
exec_dur = time() - start
timer = SimpleTimer()
timer.start(pipeline.device)
output = pipeline.run(audio_filepaths, progress_bar=progress_bar, options=options)
timer.stop(pipeline.device)
exec_dur = timer.total_sec()

# Calculate RTFX
# Calculate RTFx
data_dur, durations = calculate_duration(audio_filepaths)
rtfx = data_dur / exec_dur if exec_dur > 0 else float('inf')
logging.info(f"RTFX: {rtfx:.2f} ({data_dur:.2f}s / {exec_dur:.2f}s)")
logging.info(f"RTFx: {rtfx:.2f} ({data_dur:.2f}s / {exec_dur:.2f}s)")

# Calculate LAAL
laal = calculate_pipeline_laal(output, durations, manifest, cfg)
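The RTFx (inverse real-time factor) reported above is total audio duration divided by wall-clock processing time, so values above 1 mean faster-than-real-time processing. A one-line sketch of the formula as the script uses it, including the guard against a zero execution time:

```python
def rtfx(data_dur_sec: float, exec_dur_sec: float) -> float:
    """Inverse real-time factor: seconds of audio processed per wall-clock second."""
    return data_dur_sec / exec_dur_sec if exec_dur_sec > 0 else float("inf")


# e.g. 120 s of audio transcribed in 10 s of wall-clock time -> RTFx of 12
speedup = rtfx(120.0, 10.0)
```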
7 changes: 6 additions & 1 deletion examples/asr/conf/asr_streaming_inference/buffered_rnnt.yaml
@@ -13,6 +13,7 @@ asr:
fused_batch_size: -1
greedy:
use_cuda_graph_decoder: true
enable_per_stream_biasing: true # Per-stream biasing in decoder
max_symbols: 10
# n-gram LM
ngram_lm_model: null # The path to built '.nemo' NGPU-LM model
@@ -22,7 +23,11 @@ asr:
model_path: null # The path to built '.nemo' boosting tree model
key_phrases_file: null # The path to the context-biasing list file (one phrase per line)
key_phrases_list: null # The list of context-biasing phrases ['word1', 'word2', 'word3', ...]
source_lang: "en" # The source language of the context-biasing phrases (for aggregate tokenizer)
key_phrase_items_list: null # The list of context-biasing phrases with custom fields
# in CLI: [{phrase:"word1",lang:en},{phrase:"frase dos",lang:es}]
# in code: [PhraseItem(phrase="word1", lang="en"), PhraseItem(phrase="frase dos", lang="es")]
source_lang: "en" # The source language of the context-biasing phrases (for aggregate tokenizer),
# used with `key_phrases_file` and `key_phrases_list`
boosting_tree_alpha: 0.0


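With per-stream biasing enabled, each manifest record may carry a `biasing_request` object alongside the usual fields. A hypothetical JSONL record following the config keys shown in the YAML above — the exact field names are illustrative and may differ from the final schema:

```python
import json

# Illustrative manifest record with a per-stream biasing request; the
# "biasing_request" keys mirror the boosting config keys in the YAML.
record = {
    "audio_filepath": "audio/utt1.wav",
    "duration": 3.2,
    "text": "play nvidia nemo",
    "biasing_request": {
        "key_phrase_items_list": [
            {"phrase": "nvidia", "lang": "en"},
            {"phrase": "nemo", "lang": "en"},
        ],
        "boosting_tree_alpha": 2.0,
    },
}
line = json.dumps(record)  # one JSON object per manifest line
```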
@@ -13,6 +13,7 @@ asr:
fused_batch_size: -1
greedy:
use_cuda_graph_decoder: false # Disabled due to issues with decoding
enable_per_stream_biasing: false # Per-stream biasing in decoder
max_symbols: 10
# n-gram LM
ngram_lm_model: null # The path to built '.nemo' NGPU-LM model
@@ -22,7 +23,11 @@ asr:
model_path: null # The path to built '.nemo' boosting tree model
key_phrases_file: null # The path to the context-biasing list file (one phrase per line)
key_phrases_list: null # The list of context-biasing phrases ['word1', 'word2', 'word3', ...]
source_lang: "en" # The source language of the context-biasing phrases (for aggregate tokenizer)
key_phrase_items_list: null # The list of context-biasing phrases with custom fields
# in CLI: [{phrase:"word1",lang:en},{phrase:"frase dos",lang:es}]
# in code: [PhraseItem(phrase="word1", lang="en"), PhraseItem(phrase="frase dos", lang="es")]
source_lang: "en" # The source language of the context-biasing phrases (for aggregate tokenizer),
# used with `key_phrases_file` and `key_phrases_list`
boosting_tree_alpha: 0.0 # Weight of the boosting tree

# ==========================================
5 changes: 2 additions & 3 deletions nemo/collections/asr/inference/factory/pipeline_builder.py
@@ -13,13 +13,12 @@
# limitations under the License.


from typing import Any

import torch
from omegaconf.dictconfig import DictConfig

from nemo.collections.asr.inference.factory.buffered_pipeline_builder import BufferedPipelineBuilder
from nemo.collections.asr.inference.factory.cache_aware_pipeline_builder import CacheAwarePipelineBuilder
from nemo.collections.asr.inference.pipelines.base_pipeline import BasePipeline
from nemo.collections.asr.inference.utils.enums import PipelineType
from nemo.utils import logging

@@ -54,7 +53,7 @@ def set_log_level(log_level: int) -> None:
logging.setLevel(log_level)

@staticmethod
def build_pipeline(cfg: DictConfig) -> Any:
def build_pipeline(cfg: DictConfig) -> BasePipeline:
"""
Build the pipeline based on the config.
Args:
52 changes: 45 additions & 7 deletions nemo/collections/asr/inference/pipelines/buffered_rnnt_pipeline.py
@@ -16,6 +16,7 @@

import math
from typing import TYPE_CHECKING
import numpy as np

import torch
from omegaconf import DictConfig
@@ -39,6 +40,7 @@
update_punctuation_and_language_tokens_timestamps,
)
from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis as NemoHypothesis
from nemo.utils import logging

if TYPE_CHECKING:
from nemo.collections.asr.inference.itn.inverse_normalizer import AlignmentPreservingInverseNormalizer
@@ -520,29 +522,57 @@ def stateful_transcribe_step(
states = [self.get_state(request.stream_id) for request in requests]
partial_hypotheses, rnnt_states = [], []
all_rnnt_states_are_none = True
for state in states:
all_multi_biasing_models_empty = True
multi_biasing_ids = np.full([len(states)], fill_value=-1)
for i, state in enumerate(states):
hyp_state = state.hyp_decoding_state
rnnt_states.append(hyp_state)
if hyp_state is not None:
all_rnnt_states_are_none = False
if state.has_biasing_request():
if state.options.biasing_cfg.multi_model_id is not None:
all_multi_biasing_models_empty = False
multi_biasing_ids[i] = state.options.biasing_cfg.multi_model_id
elif state.options.biasing_cfg.auto_manage_multi_model:
state.options.biasing_cfg.add_to_multi_model(
tokenizer=self.asr_model.tokenizer,
biasing_multi_model=self.decoding_computer.biasing_multi_model,
)
multi_biasing_ids[i] = state.options.biasing_cfg.multi_model_id
all_multi_biasing_models_empty = False
else:
logging.warning("Biasing request is not empty, not auto managed and not compiled. Skipping")
if hyp_state is not None or state.has_biasing_request():
partial_hypotheses.append(
NemoHypothesis(score=0.0, y_sequence=torch.zeros([0], dtype=torch.long), dec_state=hyp_state)
NemoHypothesis(
score=0.0,
y_sequence=torch.zeros([0], dtype=torch.long),
dec_state=hyp_state,
biasing_cfg=state.options.biasing_cfg,
)
)
rnnt_states.append(hyp_state)
all_rnnt_states_are_none = False
else:
partial_hypotheses.append(None)
rnnt_states.append(None)

batched_rnnt_states = None
if not all_rnnt_states_are_none:
batched_rnnt_states = self.decoding_computer.merge_to_batched_state(rnnt_states)

if all_multi_biasing_models_empty:
multi_biasing_ids = None
else:
multi_biasing_ids = torch.from_numpy(multi_biasing_ids).to(device=enc_lens_chunk.device)

batched_state = None
if self.tokens_per_right_padding > 0:
with torch.inference_mode(), torch.no_grad():
best_hyp_chunk, alignments, batched_state = self.decoding_computer(
encs.transpose(1, 2), enc_lens_chunk, batched_rnnt_states
encs.transpose(1, 2),
enc_lens_chunk,
batched_rnnt_states,
multi_biasing_ids=multi_biasing_ids,
)

# TODO(@artbataev): remove double-decoding
best_hyp = self.asr_model.decode(encs, enc_lens, partial_hypotheses=partial_hypotheses)
if self.tokens_per_right_padding > 0 and batched_state is not None:
for state, rnnt_state in zip(states, self.decoding_computer.split_batched_state(batched_state)):
@@ -556,6 +586,14 @@ def stateful_transcribe_step(
curr_state.timestamp_offset += self.tokens_per_frame_float
ready_state_ids.update(ready_states)

for request, state in zip(requests, states):
# only the first request contains biasing options; biasing options for the stream are stored in state
if request.is_last and state.has_biasing_request():
if state.options.biasing_cfg.auto_manage_multi_model:
state.options.biasing_cfg.remove_from_multi_model(
biasing_multi_model=self.decoding_computer.biasing_multi_model
)

def decode_step(self, best_hyp: list, requests: list[Request], states: list[RNNTStreamingState]) -> set:
"""
Perform greedy RNNT decoding to get the best hypothesis and update the state.