2020import torch .nn .functional as F
2121from omegaconf import DictConfig
2222
23- from nemo .collections .asr .parts .context_biasing .biasing_multi_model import (
24- GPUBiasingMultiModel ,
25- GPUBiasingMultiModelBase ,
26- )
23+ from nemo .collections .asr .parts .context_biasing .biasing_multi_model import GPUBiasingMultiModel
2724from nemo .collections .asr .parts .submodules .ngram_lm import NGramGPULanguageModel
2825from nemo .collections .asr .parts .submodules .transducer_decoding .label_looping_base import (
2926 BatchedLabelLoopingState ,
30- FusionModelWithParams ,
3127 GreedyBatchedLabelLoopingComputerBase ,
3228 LabelLoopingStateItem ,
3329 SeparateGraphsLabelLooping ,
@@ -265,33 +261,6 @@ def __init__(
265261 self .cuda_graphs_allow_fallback = True
266262 self .maybe_enable_cuda_graphs ()
267263
268- @property
269- def per_stream_biasing_enabled (self ):
270- return self .biasing_multi_model is not None
271-
272- def _all_fusion_models (
273- self , with_multi_model : bool = True
274- ) -> list [NGramGPULanguageModel | GPUBiasingMultiModelBase ]:
275- if with_multi_model and self .per_stream_biasing_enabled :
276- return self .fusion_models + [self .biasing_multi_model ]
277- return self .fusion_models
278-
279- def _all_fusion_models_with_params (self , with_multi_model : bool = True ) -> list [FusionModelWithParams ]:
280- models_with_params = [
281- FusionModelWithParams (model = model , alpha = alpha , is_multi_model = False )
282- for model , alpha in zip (self .fusion_models , self .fusion_models_alpha )
283- ]
284- if with_multi_model and self .per_stream_biasing_enabled :
285- models_with_params .append (
286- FusionModelWithParams (model = self .biasing_multi_model , alpha = None , is_multi_model = True )
287- )
288- return models_with_params
289-
290- def has_fusion_models (self , with_multi_model : bool = True ) -> bool :
291- if len (self .fusion_models ) > 0 :
292- return True
293- return with_multi_model and self .per_stream_biasing_enabled
294-
295264 def reset_cuda_graphs_state (self ):
296265 """Reset state to release memory (for CUDA graphs implementations)"""
297266 self .state = None
@@ -306,18 +275,6 @@ def _get_frame_confidence(self, logits: torch.Tensor) -> Optional[torch.Tensor]:
306275 else None
307276 )
308277
309- def _move_fusion_models_to_device (self , device : torch .device ):
310- """
311- Move all fusion models to device.
312- We need to do this since `self` is not nn.Module instance, but owns fusion models (nn.Module instances).
313- """
314- with torch .inference_mode (mode = False ):
315- # NB: we avoid inference mode since otherwise all model params/buffers will be inference tensors,
316- # which will make further inplace manipulations impossible
317- # (e.g., `remove_model` for multi-model will throw errors)
318- for fusion_model in self ._all_fusion_models ():
319- fusion_model .to (device ) # fusion_models is nn.Module, but self is not; need to move manually
320-
321278 def torch_impl (
322279 self ,
323280 encoder_output : torch .Tensor ,
@@ -416,21 +373,14 @@ def torch_impl(
416373 scores , labels = logits .max (- 1 )
417374
418375 if self .has_fusion_models ():
419- fusion_scores_list , fusion_states_candidates_list = [], []
376+ fusion_scores_list , fusion_states_candidates_list = self .advance_fusion_models (
377+ fusion_states_list = fusion_states_list ,
378+ multi_biasing_ids = multi_biasing_ids ,
379+ float_dtype = float_dtype ,
380+ )
420381 logits_with_fusion = logits .clone ()
421- for fusion_idx , fusion_model_with_params in enumerate (self ._all_fusion_models_with_params ()):
422- fusion_scores , fusion_states_candidates = fusion_model_with_params .model .advance (
423- states = fusion_states_list [fusion_idx ],
424- ** ({"model_ids" : multi_biasing_ids } if fusion_model_with_params .is_multi_model else {}),
425- )
426- fusion_scores = fusion_scores .to (dtype = float_dtype )
427- if not fusion_model_with_params .is_multi_model :
428- fusion_scores *= fusion_model_with_params .alpha
429- # combine logits with fusion model without blank
382+ for fusion_scores in fusion_scores_list :
430383 logits_with_fusion [:, :- 1 ] += fusion_scores
431- # save fusion scores and states candidates
432- fusion_scores_list .append (fusion_scores )
433- fusion_states_candidates_list .append (fusion_states_candidates )
434384
435385 # get max scores and labels without blank
436386 fusion_scores_max , fusion_labels_max = logits_with_fusion [:, :- 1 ].max (dim = - 1 )
@@ -478,7 +428,6 @@ def torch_impl(
478428 if self .has_fusion_models ():
479429 logits_with_fusion = logits .clone ()
480430 for fusion_scores in fusion_scores_list :
481- # combined scores with fusion model - without blank
482431 logits_with_fusion [:, :- 1 ] += fusion_scores
483432 # get max scores and labels without blank
484433 more_scores_w_fusion , more_labels_w_fusion = logits_with_fusion [:, :- 1 ].max (dim = - 1 )
@@ -1132,18 +1081,19 @@ def _before_inner_loop_get_joint_output(self):
11321081 torch .max (logits , dim = - 1 , out = (self .state .scores , self .state .labels ))
11331082
11341083 if self .has_fusion_models ():
1135- for fusion_model_idx , fusion_model_with_params in enumerate (self ._all_fusion_models_with_params ()):
1084+ fusion_scores_list , fusion_states_candidates_list = self .advance_fusion_models (
1085+ fusion_states_list = self .state .fusion_states_list ,
1086+ multi_biasing_ids = self .state .multi_biasing_ids ,
1087+ float_dtype = self .state .float_dtype ,
1088+ )
1089+ for fusion_model_idx in range (len (fusion_scores_list )):
11361090 # get fusion scores/states
1137- fusion_scores , fusion_states_candidates = fusion_model_with_params .model .advance (
1138- states = self .state .fusion_states_list [fusion_model_idx ],
1139- ** ({"model_ids" : self .state .multi_biasing_ids } if fusion_model_with_params .is_multi_model else {}),
1091+ self .state .fusion_states_candidates_list [fusion_model_idx ].copy_ (
1092+ fusion_states_candidates_list [fusion_model_idx ]
11401093 )
1141- if not fusion_model_with_params .is_multi_model :
1142- fusion_scores *= fusion_model_with_params .alpha
1143- self .state .fusion_states_candidates_list [fusion_model_idx ].copy_ (fusion_states_candidates )
1144- self .state .fusion_scores_list [fusion_model_idx ].copy_ (fusion_scores .to (dtype = self .state .float_dtype ))
1094+ self .state .fusion_scores_list [fusion_model_idx ].copy_ (fusion_scores_list [fusion_model_idx ])
11451095 # update logits with fusion scores
1146- logits [:, :- 1 ] += fusion_scores
1096+ logits [:, :- 1 ] += fusion_scores_list [ fusion_model_idx ]
11471097 # get labels (greedy) and scores from current logits, replace labels/scores with new
11481098 scores_w_fusion , labels_w_fusion = logits [:, :- 1 ].max (dim = - 1 )
11491099 # preserve "blank" / "non-blank" category
0 commit comments