Skip to content

Commit 1ef5781

Browse files
authored
Transducer Decoding: Move fusion models to the base class (#15322)
* Move fusion models to the base class

Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>

---------

Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>
1 parent c6df9b4 commit 1ef5781

File tree

3 files changed

+94
-134
lines changed

3 files changed

+94
-134
lines changed

nemo/collections/asr/parts/submodules/transducer_decoding/label_looping_base.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ class CudaGraphsMode(PrettyStrEnum):
8989
max_symbols: Optional[int]
9090
allow_cuda_graphs: bool
9191
biasing_multi_model: GPUBiasingMultiModelBase | None
92+
fusion_models: list[NGramGPULanguageModel]
93+
fusion_models_alpha: list[float]
9294

9395
def force_cuda_graphs_mode(self, mode: Optional[str | CudaGraphsMode]):
9496
"""
@@ -138,6 +140,64 @@ def disable_cuda_graphs(self) -> bool:
138140
self.reset_cuda_graphs_state()
139141
return True
140142

143+
# fusion models-related methods
144+
@property
145+
def per_stream_biasing_enabled(self):
146+
return self.biasing_multi_model is not None
147+
148+
def _all_fusion_models(
149+
self, with_multi_model: bool = True
150+
) -> list[NGramGPULanguageModel | GPUBiasingMultiModelBase]:
151+
if with_multi_model and self.per_stream_biasing_enabled:
152+
return self.fusion_models + [self.biasing_multi_model]
153+
return self.fusion_models
154+
155+
def _all_fusion_models_with_params(self, with_multi_model: bool = True) -> list[FusionModelWithParams]:
156+
models_with_params = [
157+
FusionModelWithParams(model=model, alpha=alpha, is_multi_model=False)
158+
for model, alpha in zip(self.fusion_models, self.fusion_models_alpha)
159+
]
160+
if with_multi_model and self.per_stream_biasing_enabled:
161+
models_with_params.append(
162+
FusionModelWithParams(model=self.biasing_multi_model, alpha=None, is_multi_model=True)
163+
)
164+
return models_with_params
165+
166+
def has_fusion_models(self, with_multi_model: bool = True) -> bool:
167+
if len(self.fusion_models) > 0:
168+
return True
169+
return with_multi_model and self.per_stream_biasing_enabled
170+
171+
def _move_fusion_models_to_device(self, device: torch.device):
172+
"""
173+
Move all fusion models to device.
174+
We need to do this since `self` is not nn.Module instance, but owns fusion models (nn.Module instances).
175+
"""
176+
with torch.inference_mode(mode=False):
177+
# NB: we avoid inference mode since otherwise all model params/buffers will be inference tensors,
178+
# which will make further inplace manipulations impossible
179+
# (e.g., `remove_model` for multi-model will throw errors)
180+
for fusion_model in self._all_fusion_models():
181+
fusion_model.to(device) # fusion_models is nn.Module, but self is not; need to move manually
182+
183+
def advance_fusion_models(
184+
self, fusion_states_list: list[torch.Tensor], multi_biasing_ids: torch.Tensor | None, float_dtype: torch.dtype
185+
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
186+
fusion_states_candidates_list = []
187+
fusion_scores_list = []
188+
for fusion_idx, fusion_model_with_params in enumerate(self._all_fusion_models_with_params()):
189+
fusion_scores, fusion_states_candidates = fusion_model_with_params.model.advance(
190+
states=fusion_states_list[fusion_idx],
191+
**({"model_ids": multi_biasing_ids} if fusion_model_with_params.is_multi_model else {}),
192+
)
193+
fusion_scores = fusion_scores.to(dtype=float_dtype)
194+
if not fusion_model_with_params.is_multi_model:
195+
fusion_scores *= fusion_model_with_params.alpha
196+
# save fusion scores and states candidates
197+
fusion_scores_list.append(fusion_scores)
198+
fusion_states_candidates_list.append(fusion_states_candidates)
199+
return fusion_scores_list, fusion_states_candidates_list
200+
141201
@abstractmethod
142202
def torch_impl(
143203
self,

nemo/collections/asr/parts/submodules/transducer_decoding/rnnt_label_looping.py

Lines changed: 17 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,10 @@
2020
import torch.nn.functional as F
2121
from omegaconf import DictConfig
2222

23-
from nemo.collections.asr.parts.context_biasing.biasing_multi_model import (
24-
GPUBiasingMultiModel,
25-
GPUBiasingMultiModelBase,
26-
)
23+
from nemo.collections.asr.parts.context_biasing.biasing_multi_model import GPUBiasingMultiModel
2724
from nemo.collections.asr.parts.submodules.ngram_lm import NGramGPULanguageModel
2825
from nemo.collections.asr.parts.submodules.transducer_decoding.label_looping_base import (
2926
BatchedLabelLoopingState,
30-
FusionModelWithParams,
3127
GreedyBatchedLabelLoopingComputerBase,
3228
LabelLoopingStateItem,
3329
SeparateGraphsLabelLooping,
@@ -265,33 +261,6 @@ def __init__(
265261
self.cuda_graphs_allow_fallback = True
266262
self.maybe_enable_cuda_graphs()
267263

268-
@property
269-
def per_stream_biasing_enabled(self):
270-
return self.biasing_multi_model is not None
271-
272-
def _all_fusion_models(
273-
self, with_multi_model: bool = True
274-
) -> list[NGramGPULanguageModel | GPUBiasingMultiModelBase]:
275-
if with_multi_model and self.per_stream_biasing_enabled:
276-
return self.fusion_models + [self.biasing_multi_model]
277-
return self.fusion_models
278-
279-
def _all_fusion_models_with_params(self, with_multi_model: bool = True) -> list[FusionModelWithParams]:
280-
models_with_params = [
281-
FusionModelWithParams(model=model, alpha=alpha, is_multi_model=False)
282-
for model, alpha in zip(self.fusion_models, self.fusion_models_alpha)
283-
]
284-
if with_multi_model and self.per_stream_biasing_enabled:
285-
models_with_params.append(
286-
FusionModelWithParams(model=self.biasing_multi_model, alpha=None, is_multi_model=True)
287-
)
288-
return models_with_params
289-
290-
def has_fusion_models(self, with_multi_model: bool = True) -> bool:
291-
if len(self.fusion_models) > 0:
292-
return True
293-
return with_multi_model and self.per_stream_biasing_enabled
294-
295264
def reset_cuda_graphs_state(self):
296265
"""Reset state to release memory (for CUDA graphs implementations)"""
297266
self.state = None
@@ -306,18 +275,6 @@ def _get_frame_confidence(self, logits: torch.Tensor) -> Optional[torch.Tensor]:
306275
else None
307276
)
308277

309-
def _move_fusion_models_to_device(self, device: torch.device):
310-
"""
311-
Move all fusion models to device.
312-
We need to do this since `self` is not nn.Module instance, but owns fusion models (nn.Module instances).
313-
"""
314-
with torch.inference_mode(mode=False):
315-
# NB: we avoid inference mode since otherwise all model params/buffers will be inference tensors,
316-
# which will make further inplace manipulations impossible
317-
# (e.g., `remove_model` for multi-model will throw errors)
318-
for fusion_model in self._all_fusion_models():
319-
fusion_model.to(device) # fusion_models is nn.Module, but self is not; need to move manually
320-
321278
def torch_impl(
322279
self,
323280
encoder_output: torch.Tensor,
@@ -416,21 +373,14 @@ def torch_impl(
416373
scores, labels = logits.max(-1)
417374

418375
if self.has_fusion_models():
419-
fusion_scores_list, fusion_states_candidates_list = [], []
376+
fusion_scores_list, fusion_states_candidates_list = self.advance_fusion_models(
377+
fusion_states_list=fusion_states_list,
378+
multi_biasing_ids=multi_biasing_ids,
379+
float_dtype=float_dtype,
380+
)
420381
logits_with_fusion = logits.clone()
421-
for fusion_idx, fusion_model_with_params in enumerate(self._all_fusion_models_with_params()):
422-
fusion_scores, fusion_states_candidates = fusion_model_with_params.model.advance(
423-
states=fusion_states_list[fusion_idx],
424-
**({"model_ids": multi_biasing_ids} if fusion_model_with_params.is_multi_model else {}),
425-
)
426-
fusion_scores = fusion_scores.to(dtype=float_dtype)
427-
if not fusion_model_with_params.is_multi_model:
428-
fusion_scores *= fusion_model_with_params.alpha
429-
# combine logits with fusion model without blank
382+
for fusion_scores in fusion_scores_list:
430383
logits_with_fusion[:, :-1] += fusion_scores
431-
# save fusion scores and states candidates
432-
fusion_scores_list.append(fusion_scores)
433-
fusion_states_candidates_list.append(fusion_states_candidates)
434384

435385
# get max scores and labels without blank
436386
fusion_scores_max, fusion_labels_max = logits_with_fusion[:, :-1].max(dim=-1)
@@ -478,7 +428,6 @@ def torch_impl(
478428
if self.has_fusion_models():
479429
logits_with_fusion = logits.clone()
480430
for fusion_scores in fusion_scores_list:
481-
# combined scores with fusion model - without blank
482431
logits_with_fusion[:, :-1] += fusion_scores
483432
# get max scores and labels without blank
484433
more_scores_w_fusion, more_labels_w_fusion = logits_with_fusion[:, :-1].max(dim=-1)
@@ -1132,18 +1081,19 @@ def _before_inner_loop_get_joint_output(self):
11321081
torch.max(logits, dim=-1, out=(self.state.scores, self.state.labels))
11331082

11341083
if self.has_fusion_models():
1135-
for fusion_model_idx, fusion_model_with_params in enumerate(self._all_fusion_models_with_params()):
1084+
fusion_scores_list, fusion_states_candidates_list = self.advance_fusion_models(
1085+
fusion_states_list=self.state.fusion_states_list,
1086+
multi_biasing_ids=self.state.multi_biasing_ids,
1087+
float_dtype=self.state.float_dtype,
1088+
)
1089+
for fusion_model_idx in range(len(fusion_scores_list)):
11361090
# get fusion scores/states
1137-
fusion_scores, fusion_states_candidates = fusion_model_with_params.model.advance(
1138-
states=self.state.fusion_states_list[fusion_model_idx],
1139-
**({"model_ids": self.state.multi_biasing_ids} if fusion_model_with_params.is_multi_model else {}),
1091+
self.state.fusion_states_candidates_list[fusion_model_idx].copy_(
1092+
fusion_states_candidates_list[fusion_model_idx]
11401093
)
1141-
if not fusion_model_with_params.is_multi_model:
1142-
fusion_scores *= fusion_model_with_params.alpha
1143-
self.state.fusion_states_candidates_list[fusion_model_idx].copy_(fusion_states_candidates)
1144-
self.state.fusion_scores_list[fusion_model_idx].copy_(fusion_scores.to(dtype=self.state.float_dtype))
1094+
self.state.fusion_scores_list[fusion_model_idx].copy_(fusion_scores_list[fusion_model_idx])
11451095
# update logits with fusion scores
1146-
logits[:, :-1] += fusion_scores
1096+
logits[:, :-1] += fusion_scores_list[fusion_model_idx]
11471097
# get labels (greedy) and scores from current logits, replace labels/scores with new
11481098
scores_w_fusion, labels_w_fusion = logits[:, :-1].max(dim=-1)
11491099
# preserve "blank" / "non-blank" category

nemo/collections/asr/parts/submodules/transducer_decoding/tdt_label_looping.py

Lines changed: 17 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,10 @@
2020
import torch.nn.functional as F
2121
from omegaconf import DictConfig, ListConfig
2222

23-
from nemo.collections.asr.parts.context_biasing.biasing_multi_model import (
24-
GPUBiasingMultiModel,
25-
GPUBiasingMultiModelBase,
26-
)
23+
from nemo.collections.asr.parts.context_biasing.biasing_multi_model import GPUBiasingMultiModel
2724
from nemo.collections.asr.parts.submodules.ngram_lm import NGramGPULanguageModel
2825
from nemo.collections.asr.parts.submodules.transducer_decoding.label_looping_base import (
2926
BatchedLabelLoopingState,
30-
FusionModelWithParams,
3127
GreedyBatchedLabelLoopingComputerBase,
3228
LabelLoopingStateItem,
3329
SeparateGraphsLabelLooping,
@@ -292,33 +288,6 @@ def __init__(
292288
self.cuda_graphs_allow_fallback = True
293289
self.maybe_enable_cuda_graphs()
294290

295-
@property
296-
def per_stream_biasing_enabled(self):
297-
return self.biasing_multi_model is not None
298-
299-
def _all_fusion_models(
300-
self, with_multi_model: bool = True
301-
) -> list[NGramGPULanguageModel | GPUBiasingMultiModelBase]:
302-
if with_multi_model and self.per_stream_biasing_enabled:
303-
return self.fusion_models + [self.biasing_multi_model]
304-
return self.fusion_models
305-
306-
def _all_fusion_models_with_params(self, with_multi_model: bool = True) -> list[FusionModelWithParams]:
307-
models_with_params = [
308-
FusionModelWithParams(model=model, alpha=alpha, is_multi_model=False)
309-
for model, alpha in zip(self.fusion_models, self.fusion_models_alpha)
310-
]
311-
if with_multi_model and self.per_stream_biasing_enabled:
312-
models_with_params.append(
313-
FusionModelWithParams(model=self.biasing_multi_model, alpha=None, is_multi_model=True)
314-
)
315-
return models_with_params
316-
317-
def has_fusion_models(self, with_multi_model: bool = True) -> bool:
318-
if len(self.fusion_models) > 0:
319-
return True
320-
return with_multi_model and self.per_stream_biasing_enabled
321-
322291
def reset_cuda_graphs_state(self):
323292
"""Reset state to release memory (for CUDA graphs implementations)"""
324293
self.state = None
@@ -347,18 +316,6 @@ def _get_frame_confidence(self, logits: torch.Tensor, num_durations: int) -> Opt
347316
)
348317
)
349318

350-
def _move_fusion_models_to_device(self, device: torch.device):
351-
"""
352-
Move all fusion models to device.
353-
We need to do this since `self` is not nn.Module instance, but owns fusion models (nn.Module instances).
354-
"""
355-
with torch.inference_mode(mode=False):
356-
# NB: we avoid inference mode since otherwise all model params/buffers will be inference tensors,
357-
# which will make further inplace manipulations impossible
358-
# (e.g., `remove_model` for multi-model will throw errors)
359-
for fusion_model in self._all_fusion_models():
360-
fusion_model.to(device) # fusion_models is nn.Module, but self is not; need to move manually
361-
362319
def torch_impl(
363320
self,
364321
encoder_output: torch.Tensor,
@@ -467,21 +424,14 @@ def torch_impl(
467424
scores, labels = logits[:, :-num_durations].max(dim=-1)
468425

469426
if self.has_fusion_models():
470-
fusion_scores_list, fusion_states_candidates_list = [], []
427+
fusion_scores_list, fusion_states_candidates_list = self.advance_fusion_models(
428+
fusion_states_list=fusion_states_list,
429+
multi_biasing_ids=multi_biasing_ids,
430+
float_dtype=float_dtype,
431+
)
471432
logits_with_fusion = logits.clone()
472-
for fusion_idx, fusion_model_with_params in enumerate(self._all_fusion_models_with_params()):
473-
fusion_scores, fusion_states_candidates = fusion_model_with_params.model.advance(
474-
states=fusion_states_list[fusion_idx],
475-
**({"model_ids": multi_biasing_ids} if fusion_model_with_params.is_multi_model else {}),
476-
)
477-
fusion_scores = fusion_scores.to(dtype=float_dtype)
478-
if not fusion_model_with_params.is_multi_model:
479-
fusion_scores *= fusion_model_with_params.alpha
480-
# combine logits with fusion model without blank
433+
for fusion_scores in fusion_scores_list:
481434
logits_with_fusion[:, : -num_durations - 1] += fusion_scores
482-
# save fusion scores and states candidates
483-
fusion_scores_list.append(fusion_scores)
484-
fusion_states_candidates_list.append(fusion_states_candidates)
485435

486436
# get max scores and labels without blank
487437
fusion_scores_max, fusion_labels_max = logits_with_fusion[:, : -num_durations - 1].max(dim=-1)
@@ -534,7 +484,6 @@ def torch_impl(
534484
if self.has_fusion_models():
535485
logits_with_fusion = logits.clone()
536486
for fusion_scores in fusion_scores_list:
537-
# combined scores with fusion model - without blank
538487
logits_with_fusion[:, : -num_durations - 1] += fusion_scores
539488
# get max scores and labels without blank
540489
more_scores_w_fusion, more_labels_w_fusion = logits_with_fusion[:, : -num_durations - 1].max(
@@ -1212,18 +1161,19 @@ def _before_inner_loop_get_joint_output(self):
12121161
)
12131162

12141163
if self.has_fusion_models():
1215-
for fusion_model_idx, fusion_model_with_params in enumerate(self._all_fusion_models_with_params()):
1164+
fusion_scores_list, fusion_states_candidates_list = self.advance_fusion_models(
1165+
fusion_states_list=self.state.fusion_states_list,
1166+
multi_biasing_ids=self.state.multi_biasing_ids,
1167+
float_dtype=self.state.float_dtype,
1168+
)
1169+
for fusion_model_idx in range(len(fusion_scores_list)):
12161170
# get fusion scores/states
1217-
fusion_scores, fusion_states_candidates = fusion_model_with_params.model.advance(
1218-
states=self.state.fusion_states_list[fusion_model_idx],
1219-
**({"model_ids": self.state.multi_biasing_ids} if fusion_model_with_params.is_multi_model else {}),
1171+
self.state.fusion_states_candidates_list[fusion_model_idx].copy_(
1172+
fusion_states_candidates_list[fusion_model_idx]
12201173
)
1221-
if not fusion_model_with_params.is_multi_model:
1222-
fusion_scores *= fusion_model_with_params.alpha
1223-
self.state.fusion_states_candidates_list[fusion_model_idx].copy_(fusion_states_candidates)
1224-
self.state.fusion_scores_list[fusion_model_idx].copy_(fusion_scores.to(dtype=self.state.float_dtype))
1174+
self.state.fusion_scores_list[fusion_model_idx].copy_(fusion_scores_list[fusion_model_idx])
12251175
# update logits with fusion scores
1226-
logits[:, : -self.state.model_durations.shape[0] - 1] += fusion_scores
1176+
logits[:, : -self.state.model_durations.shape[0] - 1] += fusion_scores_list[fusion_model_idx]
12271177
# get labels (greedy) and scores from current logits, replace labels/scores with new
12281178
scores_w_fusion, labels_w_fusion = logits[:, : -self.state.model_durations.shape[0] - 1].max(dim=-1)
12291179
# preserve "blank" / "non-blank" category

0 commit comments

Comments
 (0)