Skip to content

Commit eaa1de5

Browse files
authored
Per-Stream Phrase Boosting in ASR Decoding (Transducers) (#15125)
* Add multi-model for boosting (slow reference `GPUBiasingMultiModelReference` and efficient fast `GPUBiasingMultiModel`), which can be further used with other decoders for per-stream context biasing * Add per-stream (per-utterance) phrase boosting, currently only for greedy label-looping decoding with transducers (RNN-T, TDT) * Enhance `BoostingTreeModelConfig` with `key_phrase_items_list` field to specify key phrases with per-phrase options (currently - allows to specify per-phrase lang to use with aggregate tokenizers). Signed-off-by: Vladimir Bataev <vbataev@nvidia.com> --------- Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>
1 parent 0233dbd commit eaa1de5

30 files changed

+2053
-313
lines changed

.github/workflows/cicd-main-speech.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,10 @@ jobs:
131131
script: L2_Speech_Transcription_Speech_to_Text_Cache_Aware_Infer
132132
- runner: self-hosted-azure
133133
script: L2_Speech_Transcription_Streaming_Inference
134+
- runner: self-hosted-azure
135+
script: L2_Speech_Transcription_Speech_to_Text_Inference_Boost_GT
136+
- runner: self-hosted-azure
137+
script: L2_Speech_Transcription_Speech_to_Text_Transcribe_Boost_GT
134138
- runner: self-hosted-azure
135139
script: L2_Speech_Transcription_Canary_Transcribe_Full_Manifest
136140
- runner: self-hosted-azure

examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,15 @@
6565
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = alloc_conf
6666

6767

68+
import librosa
6869
import lightning.pytorch as pl
6970
import torch
7071
from omegaconf import OmegaConf, open_dict
7172
from torch.utils.data import DataLoader
7273
from tqdm.auto import tqdm
7374

7475
from nemo.collections.asr.models import EncDecHybridRNNTCTCModel, EncDecRNNTModel
76+
from nemo.collections.asr.parts.context_biasing.biasing_multi_model import BiasingRequestItemConfig
7577
from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig
7678
from nemo.collections.asr.parts.submodules.transducer_decoding.label_looping_base import (
7779
GreedyBatchedLabelLoopingComputerBase,
@@ -95,6 +97,7 @@
9597
)
9698
from nemo.core.config import hydra_runner
9799
from nemo.utils import logging
100+
from nemo.utils.timers import SimpleTimer
98101

99102

100103
def make_divisible_by(num, factor: int) -> int:
@@ -145,6 +148,8 @@ class TranscriptionConfig:
145148

146149
# Decoding strategy for RNNT models
147150
decoding: RNNTDecodingConfig = field(default_factory=RNNTDecodingConfig)
151+
# Per-utterance biasing with biasing config in the manifest
152+
use_per_stream_biasing: bool = False
148153

149154
timestamps: bool = False # output timestamps
150155

@@ -154,6 +159,8 @@ class TranscriptionConfig:
154159
langid: str = "en" # specify this for convert_num_to_words step in groundtruth cleaning
155160
use_cer: bool = False
156161

162+
calculate_rtfx: bool = False
163+
157164

158165
@hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig)
159166
def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
@@ -216,6 +223,8 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
216223
asr_model = asr_model.to(asr_model.device)
217224
asr_model.to(compute_dtype)
218225

226+
use_per_stream_biasing = cfg.use_per_stream_biasing
227+
219228
# Change Decoding Config
220229
with open_dict(cfg.decoding):
221230
if cfg.decoding.strategy != "greedy_batch" or cfg.decoding.greedy.loop_labels is not True:
@@ -226,6 +235,8 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
226235
cfg.decoding.greedy.preserve_alignments = False
227236
cfg.decoding.fused_batch_size = -1 # temporarily stop fused batch during inference.
228237
cfg.decoding.beam.return_best_hypothesis = True # return and write the best hypothesis only
238+
if use_per_stream_biasing:
239+
cfg.decoding.greedy.enable_per_stream_biasing = use_per_stream_biasing
229240

230241
# Setup decoding strategy
231242
if hasattr(asr_model, 'change_decoding_strategy'):
@@ -289,8 +300,27 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
289300
latency_secs = (context_samples.chunk + context_samples.right) / audio_sample_rate
290301
logging.info(f"Theoretical latency: {latency_secs:.2f} seconds")
291302

303+
biasing_requests: list[BiasingRequestItemConfig | None] | None
304+
if use_per_stream_biasing:
305+
biasing_requests = [
306+
(
307+
BiasingRequestItemConfig(
308+
**OmegaConf.to_container(
309+
OmegaConf.merge(OmegaConf.structured(BiasingRequestItemConfig), record["biasing_request"])
310+
)
311+
)
312+
if "biasing_request" in record
313+
else None
314+
)
315+
for record in records
316+
]
317+
else:
318+
biasing_requests = None
319+
292320
audio_dataset = SimpleAudioDataset(
293-
audio_filenames=[record["audio_filepath"] for record in records], sample_rate=audio_sample_rate
321+
audio_filenames=[record["audio_filepath"] for record in records],
322+
sample_rate=audio_sample_rate,
323+
biasing_requests=biasing_requests,
294324
)
295325
audio_dataloader = DataLoader(
296326
dataset=audio_dataset,
@@ -302,9 +332,11 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
302332
in_order=True,
303333
)
304334

335+
timer = SimpleTimer()
305336
with torch.no_grad(), torch.inference_mode():
306337
all_hyps = []
307338
audio_data: AudioBatch
339+
timer.start(device=map_location)
308340
for audio_data in tqdm(audio_dataloader):
309341
# get audio
310342
# NB: preprocessor runs on torch.float32, no need to cast dtype here
@@ -313,8 +345,21 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
313345
batch_size = audio_batch.shape[0]
314346
device = audio_batch.device
315347

316-
# decode audio by chunks
348+
# add biasing requests to the decoder
349+
if use_per_stream_biasing:
350+
multi_biasing_ids = torch.full([batch_size], fill_value=-1, dtype=torch.long, device=map_location)
351+
if audio_data.biasing_requests is not None:
352+
for batch_i, request in enumerate(audio_data.biasing_requests):
353+
if request is not None:
354+
biasing_model = request.get_model(tokenizer=asr_model.tokenizer)
355+
if biasing_model is not None:
356+
multi_model_id = decoding_computer.biasing_multi_model.add_model(biasing_model)
357+
request.multi_model_id = multi_model_id
358+
multi_biasing_ids[batch_i] = multi_model_id
359+
else:
360+
multi_biasing_ids = None
317361

362+
# decode audio by chunks
318363
current_batched_hyps: BatchedHyps | None = None
319364
state = None
320365
left_sample = 0
@@ -368,6 +413,7 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
368413
encoder_context_batch.chunk,
369414
),
370415
prev_batched_state=state,
416+
multi_biasing_ids=multi_biasing_ids,
371417
)
372418
# merge hyps with previous hyps
373419
if current_batched_hyps is None:
@@ -380,7 +426,14 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
380426
left_sample = right_sample
381427
right_sample = min(right_sample + context_samples.chunk, audio_batch.shape[1]) # add next chunk
382428

429+
# remove biasing requests from the decoder
430+
if use_per_stream_biasing and audio_data.biasing_requests is not None:
431+
for request in audio_data.biasing_requests:
432+
if request is not None and request.multi_model_id is not None:
433+
decoding_computer.biasing_multi_model.remove_model(request.multi_model_id)
434+
request.multi_model_id = None
383435
all_hyps.extend(batched_hyps_to_hypotheses(current_batched_hyps, None, batch_size=batch_size))
436+
timer.stop(device=map_location)
384437

385438
# convert text
386439
for i, hyp in enumerate(all_hyps):
@@ -399,6 +452,14 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
399452
)
400453
logging.info(f"Finished writing predictions to {output_filename}!")
401454

455+
if cfg.calculate_rtfx:
456+
durations = [
457+
record["duration"] if "duration" in record else librosa.get_duration(path=record["audio_filepath"])
458+
for record in records
459+
]
460+
rtfx_measurements = sum(durations) / timer.total_sec()
461+
logging.info(f"Model RTFx on the dataset: {rtfx_measurements:.3f}")
462+
402463
if cfg.calculate_wer:
403464
output_manifest_w_wer, total_res, _ = cal_write_wer(
404465
pred_manifest=output_filename,

examples/asr/asr_streaming_inference/asr_streaming_infer.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,17 @@
4242
"""
4343

4444

45-
from time import time
46-
4745
import hydra
46+
from omegaconf import OmegaConf
4847

4948
from nemo.collections.asr.inference.factory.pipeline_builder import PipelineBuilder
49+
from nemo.collections.asr.inference.streaming.framing.request_options import ASRRequestOptions
5050
from nemo.collections.asr.inference.utils.manifest_io import calculate_duration, dump_output, get_audio_filepaths
5151
from nemo.collections.asr.inference.utils.pipeline_eval import calculate_pipeline_laal, evaluate_pipeline
5252
from nemo.collections.asr.inference.utils.progressbar import TQDMProgressBar
53+
from nemo.collections.asr.parts.context_biasing.biasing_multi_model import BiasingRequestItemConfig
5354
from nemo.utils import logging
55+
from nemo.utils.timers import SimpleTimer
5456

5557
# disable nemo_text_processing logging
5658
try:
@@ -80,15 +82,36 @@ def main(cfg):
8082
pipeline = PipelineBuilder.build_pipeline(cfg)
8183
progress_bar = TQDMProgressBar()
8284

85+
# Add biasing requests
86+
if manifest:
87+
options = [
88+
ASRRequestOptions(
89+
biasing_cfg=(
90+
BiasingRequestItemConfig(
91+
**OmegaConf.to_container(
92+
OmegaConf.merge(OmegaConf.structured(BiasingRequestItemConfig), record["biasing_request"])
93+
)
94+
)
95+
if "biasing_request" in record
96+
else None
97+
)
98+
)
99+
for record in manifest
100+
]
101+
else:
102+
options = None
103+
83104
# Run the pipeline
84-
start = time()
85-
output = pipeline.run(audio_filepaths, progress_bar=progress_bar)
86-
exec_dur = time() - start
105+
timer = SimpleTimer()
106+
timer.start(pipeline.device)
107+
output = pipeline.run(audio_filepaths, progress_bar=progress_bar, options=options)
108+
timer.stop(pipeline.device)
109+
exec_dur = timer.total_sec()
87110

88-
# Calculate RTFX
111+
# Calculate RTFx
89112
data_dur, durations = calculate_duration(audio_filepaths)
90113
rtfx = data_dur / exec_dur if exec_dur > 0 else float('inf')
91-
logging.info(f"RTFX: {rtfx:.2f} ({data_dur:.2f}s / {exec_dur:.2f}s)")
114+
logging.info(f"RTFx: {rtfx:.2f} ({data_dur:.2f}s / {exec_dur:.2f}s)")
92115

93116
# Calculate LAAL
94117
laal = calculate_pipeline_laal(output, durations, manifest, cfg)

examples/asr/conf/asr_streaming_inference/buffered_rnnt.yaml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ asr:
1313
fused_batch_size: -1
1414
greedy:
1515
use_cuda_graph_decoder: true
16+
enable_per_stream_biasing: true # Per-stream biasing in decoder
1617
max_symbols: 10
1718
# n-gram LM
1819
ngram_lm_model: null # The path to built '.nemo' NGPU-LM model
@@ -22,7 +23,11 @@ asr:
2223
model_path: null # The path to built '.nemo' boosting tree model
2324
key_phrases_file: null # The path to the context-biasing list file (one phrase per line)
2425
key_phrases_list: null # The list of context-biasing phrases ['word1', 'word2', 'word3', ...]
25-
source_lang: "en" # The source language of the context-biasing phrases (for aggregate tokenizer)
26+
key_phrase_items_list: null # The list of context-biasing phrases with custom fields
27+
# in CLI: [{phrase:"word1",lang:en},{phrase:"frase dos",lang:es}]
28+
# in code: [PhraseItem(phrase="word1", lang="en"), PhraseItem(phrase="frase dos", lang="es")]
29+
source_lang: "en" # The source language of the context-biasing phrases (for aggregate tokenizer),
30+
# used with `key_phrases_file` and `key_phrases_list`
2631
boosting_tree_alpha: 0.0
2732

2833

examples/asr/conf/asr_streaming_inference/cache_aware_rnnt.yaml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ asr:
1313
fused_batch_size: -1
1414
greedy:
1515
use_cuda_graph_decoder: false # Disabled due to issues with decoding
16+
enable_per_stream_biasing: false # Per-stream biasing in decoder
1617
max_symbols: 10
1718
# n-gram LM
1819
ngram_lm_model: null # The path to built '.nemo' NGPU-LM model
@@ -22,7 +23,11 @@ asr:
2223
model_path: null # The path to built '.nemo' boosting tree model
2324
key_phrases_file: null # The path to the context-biasing list file (one phrase per line)
2425
key_phrases_list: null # The list of context-biasing phrases ['word1', 'word2', 'word3', ...]
25-
source_lang: "en" # The source language of the context-biasing phrases (for aggregate tokenizer)
26+
key_phrase_items_list: null # The list of context-biasing phrases with custom fields
27+
# in CLI: [{phrase:"word1",lang:en},{phrase:"frase dos",lang:es}]
28+
# in code: [PhraseItem(phrase="word1", lang="en"), PhraseItem(phrase="frase dos", lang="es")]
29+
source_lang: "en" # The source language of the context-biasing phrases (for aggregate tokenizer),
30+
# used with `key_phrases_file` and `key_phrases_list`
2631
boosting_tree_alpha: 0.0 # Weight of the boosting tree
2732

2833
# ==========================================

nemo/collections/asr/inference/factory/pipeline_builder.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,12 @@
1313
# limitations under the License.
1414

1515

16-
from typing import Any
17-
1816
import torch
1917
from omegaconf.dictconfig import DictConfig
2018

2119
from nemo.collections.asr.inference.factory.buffered_pipeline_builder import BufferedPipelineBuilder
2220
from nemo.collections.asr.inference.factory.cache_aware_pipeline_builder import CacheAwarePipelineBuilder
21+
from nemo.collections.asr.inference.pipelines.base_pipeline import BasePipeline
2322
from nemo.collections.asr.inference.utils.enums import PipelineType
2423
from nemo.utils import logging
2524

@@ -54,7 +53,7 @@ def set_log_level(log_level: int) -> None:
5453
logging.setLevel(log_level)
5554

5655
@staticmethod
57-
def build_pipeline(cfg: DictConfig) -> Any:
56+
def build_pipeline(cfg: DictConfig) -> BasePipeline:
5857
"""
5958
Build the pipeline based on the config.
6059
Args:

nemo/collections/asr/inference/pipelines/buffered_rnnt_pipeline.py

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import math
1818
from typing import TYPE_CHECKING
19+
import numpy as np
1920

2021
import torch
2122
from omegaconf import DictConfig
@@ -39,6 +40,7 @@
3940
update_punctuation_and_language_tokens_timestamps,
4041
)
4142
from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis as NemoHypothesis
43+
from nemo.utils import logging
4244

4345
if TYPE_CHECKING:
4446
from nemo.collections.asr.inference.itn.inverse_normalizer import AlignmentPreservingInverseNormalizer
@@ -520,29 +522,57 @@ def stateful_transcribe_step(
520522
states = [self.get_state(request.stream_id) for request in requests]
521523
partial_hypotheses, rnnt_states = [], []
522524
all_rnnt_states_are_none = True
523-
for state in states:
525+
all_multi_biasing_models_empty = True
526+
multi_biasing_ids = np.full([len(states)], fill_value=-1)
527+
for i, state in enumerate(states):
524528
hyp_state = state.hyp_decoding_state
529+
rnnt_states.append(hyp_state)
525530
if hyp_state is not None:
531+
all_rnnt_states_are_none = False
532+
if state.has_biasing_request():
533+
if state.options.biasing_cfg.multi_model_id is not None:
534+
all_multi_biasing_models_empty = False
535+
multi_biasing_ids[i] = state.options.biasing_cfg.multi_model_id
536+
elif state.options.biasing_cfg.auto_manage_multi_model:
537+
state.options.biasing_cfg.add_to_multi_model(
538+
tokenizer=self.asr_model.tokenizer,
539+
biasing_multi_model=self.decoding_computer.biasing_multi_model,
540+
)
541+
multi_biasing_ids[i] = state.options.biasing_cfg.multi_model_id
542+
all_multi_biasing_models_empty = False
543+
else:
544+
logging.warning("Biasing request is not empty, not auto managed and not compiled. Skipping")
545+
if hyp_state is not None or state.has_biasing_request():
526546
partial_hypotheses.append(
527-
NemoHypothesis(score=0.0, y_sequence=torch.zeros([0], dtype=torch.long), dec_state=hyp_state)
547+
NemoHypothesis(
548+
score=0.0,
549+
y_sequence=torch.zeros([0], dtype=torch.long),
550+
dec_state=hyp_state,
551+
biasing_cfg=state.options.biasing_cfg,
552+
)
528553
)
529-
rnnt_states.append(hyp_state)
530-
all_rnnt_states_are_none = False
531554
else:
532555
partial_hypotheses.append(None)
533-
rnnt_states.append(None)
534556

535557
batched_rnnt_states = None
536558
if not all_rnnt_states_are_none:
537559
batched_rnnt_states = self.decoding_computer.merge_to_batched_state(rnnt_states)
538560

561+
if all_multi_biasing_models_empty:
562+
multi_biasing_ids = None
563+
else:
564+
multi_biasing_ids = torch.from_numpy(multi_biasing_ids).to(device=enc_lens_chunk.device)
565+
539566
batched_state = None
540567
if self.tokens_per_right_padding > 0:
541568
with torch.inference_mode(), torch.no_grad():
542569
best_hyp_chunk, alignments, batched_state = self.decoding_computer(
543-
encs.transpose(1, 2), enc_lens_chunk, batched_rnnt_states
570+
encs.transpose(1, 2),
571+
enc_lens_chunk,
572+
batched_rnnt_states,
573+
multi_biasing_ids=multi_biasing_ids,
544574
)
545-
575+
# TODO(@artbataev): remove double-decoding
546576
best_hyp = self.asr_model.decode(encs, enc_lens, partial_hypotheses=partial_hypotheses)
547577
if self.tokens_per_right_padding > 0 and batched_state is not None:
548578
for state, rnnt_state in zip(states, self.decoding_computer.split_batched_state(batched_state)):
@@ -556,6 +586,14 @@ def stateful_transcribe_step(
556586
curr_state.timestamp_offset += self.tokens_per_frame_float
557587
ready_state_ids.update(ready_states)
558588

589+
for request, state in zip(requests, states):
590+
# only the first request contains biasing options; biasing options for the stream are stored in state
591+
if request.is_last and state.has_biasing_request():
592+
if state.options.biasing_cfg.auto_manage_multi_model:
593+
state.options.biasing_cfg.remove_from_multi_model(
594+
biasing_multi_model=self.decoding_computer.biasing_multi_model
595+
)
596+
559597
def decode_step(self, best_hyp: list, requests: list[Request], states: list[RNNTStreamingState]) -> set:
560598
"""
561599
Perform greedy RNNT decoding to get the best hypothesis and update the state.

0 commit comments

Comments
 (0)