Skip to content
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
82 commits
Select commit Hold shift + click to select a range
b4c33dd
Fused biasing ids in interface
artbataev Nov 17, 2025
0e8f25d
Bugfix
artbataev Nov 17, 2025
066d5d3
Fused model: non-batched reference implementation
artbataev Nov 17, 2025
9d65743
Change api
artbataev Nov 19, 2025
28aca3d
Temporary bugfix
artbataev Nov 20, 2025
fe42d8a
Allow partial hypothesis without dec_state
artbataev Nov 20, 2025
ca2cf4b
Multi-model biasing: implement with RNN-T
artbataev Nov 21, 2025
a290f32
Multi-model biasing: implement for TDT, fix RNN-T
artbataev Nov 21, 2025
e288228
Add and remove boosting model from multi-model: move logic to Hypothesis
artbataev Nov 22, 2025
8d36596
Clean up per-hyp biasing config
artbataev Nov 26, 2025
b8df95e
Auto manage biasing requests in transcribe
artbataev Nov 26, 2025
e74f7b1
Disable CUDA graphs with per-stream biasing
artbataev Nov 26, 2025
558dba2
Support specifying lang per phrase
artbataev Nov 26, 2025
972fc29
Fix customization options
artbataev Nov 26, 2025
352027f
Fix streaming decoding
artbataev Nov 26, 2025
e4e7fc3
Fix streaming decoding
artbataev Nov 26, 2025
ac88aa3
Fix type
artbataev Nov 26, 2025
768e51a
Implement custom requests in nemo inference
artbataev Nov 26, 2025
e24fc99
Merge branch 'main' into vbataev/multi_biasing_models
artbataev Nov 26, 2025
7209d63
Use biasing options from state instead of request
artbataev Nov 26, 2025
736f942
Fix f-strings
artbataev Nov 26, 2025
e5f9c71
Merge branch 'main' into vbataev/multi_biasing_models
artbataev Nov 28, 2025
d8649b7
Expose parameter: enable_per_stream_biasing
artbataev Nov 28, 2025
285ac29
key_phrase_items_list: use parameter in config
artbataev Nov 28, 2025
59139ef
More documentation
artbataev Nov 28, 2025
3680f07
Clean up biasing_cfg management
artbataev Nov 28, 2025
8aea8ad
Specify todo
artbataev Nov 28, 2025
2b20d8e
Clean up
artbataev Nov 28, 2025
4c55a79
Stubs for optimized model
artbataev Nov 28, 2025
5875202
Multi-model: implement adding model
artbataev Dec 1, 2025
1c04627
Multi-model: stubs for advance
artbataev Dec 1, 2025
6840f0f
Multi-model: stubs for advance using Triton
artbataev Dec 1, 2025
494c4c1
Multi-model: implement Triton kernel
artbataev Dec 1, 2025
db84e03
Multi-model: implement advance in Pytorch
artbataev Dec 1, 2025
225a644
Multi-model: use optimized implementation
artbataev Dec 1, 2025
586221e
Multi-model: implement model removal
artbataev Dec 1, 2025
398e4ed
Clean up
artbataev Dec 1, 2025
b9a4ba2
Clean up
artbataev Dec 1, 2025
cce5b26
Fix triton implementation
artbataev Dec 1, 2025
13af31b
Fix flake8 suggestions
artbataev Dec 1, 2025
985ccb4
Merge branch 'main' into vbataev/multi_biasing_models
artbataev Dec 3, 2025
5cfc624
Support CUDA graphs for RNN-T
artbataev Dec 3, 2025
7a00812
Support CUDA graphs for TDT
artbataev Dec 3, 2025
0d44900
Fix TDT
artbataev Dec 3, 2025
dfdb876
Merge branch 'main' into vbataev/multi_biasing_models
artbataev Dec 8, 2025
bfcf441
Fix vocab size
artbataev Dec 9, 2025
fb20279
Fix edge case
artbataev Dec 9, 2025
dac3ef8
Add test
artbataev Dec 9, 2025
4aa7b9d
Remove OneLogger causing issues
artbataev Dec 9, 2025
71638dc
Clean up logging
artbataev Dec 9, 2025
90e2ed3
Fix model removal
artbataev Dec 9, 2025
3c9b2db
Efficient addition and removal of models
artbataev Dec 10, 2025
103c4af
Fix Triton implementation
artbataev Dec 10, 2025
bc80bf5
Fix CUDA graphs
artbataev Dec 10, 2025
01b361c
Clean up
artbataev Dec 10, 2025
0c62cb5
Bugfix
artbataev Dec 10, 2025
f47ebd5
Bugfix
artbataev Dec 10, 2025
bec685c
Clean up
artbataev Dec 10, 2025
b7e78ef
Merge branch 'main' into vbataev/multi_biasing_models
artbataev Dec 10, 2025
28bc0d3
Add docstrings
artbataev Dec 22, 2025
ca58d79
Fix moving fusion models to the device
artbataev Dec 22, 2025
39d73f6
Fix inference mode issues
artbataev Dec 22, 2025
4cae6dd
Merge branch 'main' into vbataev/multi_biasing_models
artbataev Dec 22, 2025
526e0c8
Revert back One Logger callback
artbataev Dec 22, 2025
bf06b3b
Implement memory cache
artbataev Dec 22, 2025
9acac89
Fix model removal
artbataev Dec 22, 2025
a36f3fa
Clean up
artbataev Dec 22, 2025
adf29aa
Add nemo inference test with boosting ground truth
artbataev Dec 22, 2025
5e0bd80
Add test with `asr_model.transcribe`
artbataev Dec 23, 2025
a4e1582
Clean up
artbataev Dec 23, 2025
cc0c3a1
Clean up
artbataev Dec 23, 2025
fad3cad
Clean up decoders
artbataev Dec 23, 2025
8e9a986
Merge branch 'main' into vbataev/multi_biasing_models
artbataev Dec 23, 2025
1908759
Add unit tests
artbataev Dec 23, 2025
d49b342
Fix copyright
artbataev Dec 23, 2025
5589360
Fix tests
artbataev Dec 25, 2025
60f53aa
Support biasing request in manifest for streaming inference scripts
artbataev Jan 12, 2026
6c3981c
Merge branch 'main' into vbataev/multi_biasing_models
artbataev Jan 12, 2026
e33942a
Use `asr_streaming_infer.py` in functional tests. Add test with per-s…
artbataev Jan 12, 2026
e8d0524
Remove redundant WER calculation
artbataev Jan 13, 2026
5470fc7
Clean up. Add docstring
artbataev Jan 13, 2026
9b479f8
Remove unused import
artbataev Jan 13, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions nemo/collections/asr/inference/factory/pipeline_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,12 @@
# limitations under the License.


from typing import Any

import torch
from omegaconf.dictconfig import DictConfig

from nemo.collections.asr.inference.factory.buffered_pipeline_builder import BufferedPipelineBuilder
from nemo.collections.asr.inference.factory.cache_aware_pipeline_builder import CacheAwarePipelineBuilder
from nemo.collections.asr.inference.pipelines.base_pipeline import BasePipeline
from nemo.collections.asr.inference.utils.enums import PipelineType
from nemo.utils import logging

Expand Down Expand Up @@ -54,7 +53,7 @@ def set_log_level(log_level: int) -> None:
logging.setLevel(log_level)

@staticmethod
def build_pipeline(cfg: DictConfig) -> Any:
def build_pipeline(cfg: DictConfig) -> BasePipeline:
"""
Build the pipeline based on the config.
Args:
Expand Down
52 changes: 45 additions & 7 deletions nemo/collections/asr/inference/pipelines/buffered_rnnt_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import math
from typing import TYPE_CHECKING
import numpy as np

import torch
from omegaconf import DictConfig
Expand All @@ -39,6 +40,7 @@
update_punctuation_and_language_tokens_timestamps,
)
from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis as NemoHypothesis
from nemo.utils import logging

if TYPE_CHECKING:
from nemo.collections.asr.inference.itn.inverse_normalizer import AlignmentPreservingInverseNormalizer
Expand Down Expand Up @@ -463,29 +465,58 @@ def stateful_transcribe_step(
states = [self.get_state(request.stream_id) for request in requests]
partial_hypotheses, rnnt_states = [], []
all_rnnt_states_are_none = True
for state in states:
all_multi_biasing_models_empty = True
multi_biasing_ids = np.full([len(states)], fill_value=-1)
for i, state in enumerate(states):
hyp_state = state.hyp_decoding_state
rnnt_states.append(hyp_state)
if hyp_state is not None:
all_rnnt_states_are_none = False
if state.options is not None and state.options.has_biasing_request():
if state.options.biasing_cfg.multi_model_id is not None:
all_multi_biasing_models_empty = False
multi_biasing_ids[i] = state.options.biasing_cfg.multi_model_id
elif state.options.biasing_cfg.auto_manage_multi_model:
state.options.biasing_cfg.add_to_multi_model(
tokenizer=self.asr_model.tokenizer,
biasing_multi_model=self.decoding_computer.biasing_multi_model,
)
assert state.options.biasing_cfg.multi_model_id is not None
all_multi_biasing_models_empty = False
multi_biasing_ids[i] = state.options.biasing_cfg.multi_model_id
else:
logging.warning("Biasing request is not empty, not auto managed and not compiled. Skipping")
if hyp_state is not None or (state.options is not None and state.options.has_biasing_request()):
partial_hypotheses.append(
NemoHypothesis(score=0.0, y_sequence=torch.zeros([0], dtype=torch.long), dec_state=hyp_state)
NemoHypothesis(
score=0.0,
y_sequence=torch.zeros([0], dtype=torch.long),
dec_state=hyp_state,
biasing_cfg=state.options.biasing_cfg,
)
)
rnnt_states.append(hyp_state)
all_rnnt_states_are_none = False
else:
partial_hypotheses.append(None)
rnnt_states.append(None)

batched_rnnt_states = None
if not all_rnnt_states_are_none:
batched_rnnt_states = self.decoding_computer.merge_to_batched_state(rnnt_states)

if all_multi_biasing_models_empty:
multi_biasing_ids = None
else:
multi_biasing_ids = torch.from_numpy(multi_biasing_ids).to(device=enc_lens_chunk.device)

batched_state = None
if self.tokens_per_right_padding > 0:
with torch.inference_mode(), torch.no_grad():
best_hyp_chunk, alignments, batched_state = self.decoding_computer(
encs.transpose(1, 2), enc_lens_chunk, batched_rnnt_states
encs.transpose(1, 2),
enc_lens_chunk,
batched_rnnt_states,
multi_biasing_ids=multi_biasing_ids,
)

# TODO: remove double-decoding
best_hyp = self.asr_model.decode(encs, enc_lens, partial_hypotheses=partial_hypotheses)
if self.tokens_per_right_padding > 0 and batched_state is not None:
for state, rnnt_state in zip(states, self.decoding_computer.split_batched_state(batched_state)):
Expand All @@ -499,6 +530,13 @@ def stateful_transcribe_step(
curr_state.timestamp_offset += self.tokens_per_frame_float
ready_state_ids.update(ready_states)

for request, state in zip(requests, states):
# only the first request contains biasing options; biasing options for the stream are stored in state
if request.is_last and (state.options is not None and state.options.has_biasing_request()):
state.options.biasing_cfg.remove_from_multi_model(
biasing_multi_model=self.decoding_computer.biasing_multi_model
)

def decode_step(self, best_hyp: list, requests: list[Request], states: list[RNNTStreamingState]) -> set:
"""
Perform greedy RNNT decoding to get the best hypothesis and update the state.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
# limitations under the License.


from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import TypeAlias
from nemo.collections.asr.inference.utils.enums import ASROutputGranularity
from nemo.collections.asr.parts.context_biasing.biasing_multi_model import BiasingRequestItemConfig


@dataclass(slots=True)
Expand All @@ -29,6 +30,7 @@ class ASRRequestOptions:
enable_pnc: bool = None
stop_history_eou: int = None
asr_output_granularity: ASROutputGranularity | str = None
biasing_cfg: BiasingRequestItemConfig = field(default_factory=BiasingRequestItemConfig)

def __post_init__(self) -> None:
"""
Expand Down Expand Up @@ -76,7 +78,11 @@ def augment_with_defaults(
asr_output_granularity=(
default_asr_output_granularity if self.asr_output_granularity is None else self.asr_output_granularity
),
biasing_cfg=self.biasing_cfg,
)

def has_biasing_request(self) -> bool:
    """Return True if this request carries a non-empty biasing (word-boosting) configuration."""
    return not self.biasing_cfg.is_empty()


RequestOptions: TypeAlias = ASRRequestOptions
2 changes: 2 additions & 0 deletions nemo/collections/asr/inference/streaming/state/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ class StreamingState:
Generic state for the streaming ASR pipeline
"""

options: RequestOptions | None

def __init__(self):
"""
Initialize the StreamingState
Expand Down
36 changes: 36 additions & 0 deletions nemo/collections/asr/models/rnnt_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1001,6 +1001,42 @@ def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLo
temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config))
return temporary_datalayer

def _transcribe_on_begin(self, audio, trcfg: TranscribeConfig):
    """Hook run before transcription: register auto-managed biasing requests.

    For every partial hypothesis that carries a biasing config which is
    auto-managed and not yet compiled (no ``multi_model_id``), compile it
    into the decoding computer's multi-model so decoding can use it.
    """
    super()._transcribe_on_begin(audio=audio, trcfg=trcfg)
    # The decoding computer (and its multi-model) only exists for some
    # decoding strategies; fall back to None when the attribute chain breaks.
    try:
        multi_model = self.decoding.decoding.decoding_computer.biasing_multi_model
    except AttributeError:
        multi_model = None
    if multi_model is None or not trcfg.partial_hypothesis:
        return
    for hyp in trcfg.partial_hypothesis:
        if not isinstance(hyp, Hypothesis) or not hyp.has_biasing_request():
            continue
        cfg = hyp.biasing_cfg
        # Only compile configs that opted into auto management and are not
        # already registered (multi_model_id is assigned on registration).
        if cfg.auto_manage_multi_model and cfg.multi_model_id is None:
            cfg.add_to_multi_model(tokenizer=self.tokenizer, biasing_multi_model=multi_model)

def _transcribe_on_end(self, trcfg: TranscribeConfig):
    """Hook run after transcription: deregister auto-managed biasing requests.

    Mirrors ``_transcribe_on_begin``: every auto-managed biasing config found
    on the partial hypotheses is removed from the decoding computer's
    multi-model so resources are freed once transcription finishes.
    """
    super()._transcribe_on_end(trcfg=trcfg)
    # Not every decoding strategy exposes a decoding computer / multi-model.
    try:
        multi_model = self.decoding.decoding.decoding_computer.biasing_multi_model
    except AttributeError:
        multi_model = None
    if multi_model is None or not trcfg.partial_hypothesis:
        return
    for hyp in trcfg.partial_hypothesis:
        if not isinstance(hyp, Hypothesis) or not hyp.has_biasing_request():
            continue
        if hyp.biasing_cfg.auto_manage_multi_model:
            hyp.biasing_cfg.remove_from_multi_model(biasing_multi_model=multi_model)

def on_after_backward(self):
super().on_after_backward()
if self._optim_variational_noise_std > 0 and self.global_step >= self._optim_variational_noise_start:
Expand Down
Loading
Loading