Changes required for enabling prompt based models in Nemo Inference #15036

arushidNV merged 22 commits into NVIDIA-NeMo:main
Conversation
Pull Request Overview
This PR adds support for RNNT multilingual models with prompt-based language selection in NeMo Inference. The implementation enables language-specific prompts to be passed through the inference pipeline, allowing a single model to handle multiple languages.
Key Changes:
- Added `language_code` field to `ASRRequestOptions` for specifying the target language per request
- Introduced `prompt_idx` tracking in streaming state to maintain language selection across the stream lifecycle
- Implemented prompt vector generation and caching infrastructure in the buffered RNNT pipeline with validation and efficient batch processing
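The selection flow in the changes above can be sketched with a minimal, self-contained example. The `prompt_dict` contents and the `resolve_prompt_index` helper below are illustrative, not the PR's actual API: a per-request language code resolves to a prompt index, and a batch of indices is turned into one-hot prompt vectors via an identity matrix, in the spirit of the "one hot vector instead of caching" approach mentioned in the commit history.

```python
import torch

# Illustrative prompt dictionary; real models ship their own mapping
# (language code -> prompt index) in their configuration.
prompt_dict = {"en-US": 0, "de-DE": 1, "es-ES": 2}
num_prompts = len(prompt_dict)

def resolve_prompt_index(language_code: str) -> int:
    # Hypothetical helper mirroring the pipeline's validation:
    # unknown language codes are rejected up front.
    if language_code not in prompt_dict:
        raise ValueError(f"Unsupported language code: {language_code}")
    return prompt_dict[language_code]

# One stream per batch entry; each stream stores its resolved prompt index.
batch_language_codes = ["de-DE", "en-US", "de-DE"]
indices = torch.tensor(
    [resolve_prompt_index(c) for c in batch_language_codes], dtype=torch.long
)

# One-hot prompt matrix: row i is the one-hot vector selecting prompt i.
prompt_matrix = torch.eye(num_prompts)
prompt_vectors = prompt_matrix.index_select(0, indices)  # shape [B, num_prompts]
```

With three prompts, `prompt_vectors` here is a 3x3 tensor whose rows are the one-hot encodings for de-DE, en-US, de-DE.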
Reviewed Changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 9 comments.
| File | Description |
|---|---|
| nemo/collections/asr/inference/streaming/state/state.py | Adds `prompt_idx` field and `set_prompt_index()` method to track prompt index per stream |
| nemo/collections/asr/inference/streaming/framing/request_options.py | Adds `language_code` field to request options for language selection |
| nemo/collections/asr/inference/model_wrappers/rnnt_inference_wrapper.py | Extends encode methods to support optional prompt vectors with dimension expansion |
| nemo/collections/asr/inference/pipelines/buffered_rnnt_pipeline.py | Implements prompt support infrastructure including configuration loading, validation, and prompt vector batch generation |
```
@@ -212,6 +304,15 @@ def create_state(self, options: ASRRequestOptions) -> RNNTStreamingState:
    default_asr_output_granularity=self.asr_output_granularity,
```

The `augment_with_defaults` call is missing the `default_language_code` parameter. According to the updated signature in request_options.py, this parameter should be passed here. Without it, the default language code will always be None even if a default was intended to be set.

Consider adding:

```python
new_options = options.augment_with_defaults(
    default_enable_itn=self.text_processor.is_itn_enabled(),
    default_enable_pnc=self.text_processor.is_pnc_enabled(),
    default_stop_history_eou=self.stop_history_eou_in_milliseconds,
    default_asr_output_granularity=self.asr_output_granularity,
    default_language_code=None,  # or "en-US" if a default is desired
)
```
```python
# Build prompt vectors if prompts are enabled
if self.prompt_enabled:
    requests_states = [self.get_state(f.stream_id) for f in frames]
    indices = torch.tensor([s.prompt_idx for s in requests_states], device=self.device, dtype=torch.long)
```

The code assumes `s.prompt_idx` is set for all states when `self.prompt_enabled` is True. However, if a state was created before the prompt feature was added or if `set_prompt_index()` wasn't called for some reason, this will cause an AttributeError when accessing `s.prompt_idx`.

Consider adding a check or a default value:

```python
indices = torch.tensor([getattr(s, 'prompt_idx', 0) for s in requests_states], device=self.device, dtype=torch.long)
```

Or better, ensure `prompt_idx` is always initialized in the state's `_reset_streaming_state` method.
```python
# Build prompt vectors if prompts are enabled
if self.prompt_enabled:
    requests_states = [self.get_state(f.stream_id) for f in fbuffers]
    indices = torch.tensor([s.prompt_idx for s in requests_states], device=self.device, dtype=torch.long)
```

Same issue as line 402: the code assumes `s.prompt_idx` is always set for all states. If a state was created before the prompt feature was added or if `set_prompt_index()` wasn't called, this will cause an AttributeError.

Consider adding a check or default value as in line 402:

```python
indices = torch.tensor([getattr(s, "prompt_idx", 0) for s in requests_states], device=self.device, dtype=torch.long)
```
```python
Args:
    processed_signal: (Tensor) processed signal. Shape is torch.Size([B, C, T]).
    processed_signal_length: (Tensor) processed signal length. Shape is torch.Size([B]).
    prompt_vectors: (Tensor | None) Optional prompt vectors for multilingual models. Shape is torch.Size([B, num_prompts]).
```

The docstring states that `prompt_vectors` should have shape [B, num_prompts], but the method `encode_with_prompts()` expands it to [B, T_enc, num_prompts] before calling this method. This creates an inconsistency in the expected shape documentation.

Either:
- Update the docstring to clarify that both shapes are accepted, or
- Document that this method accepts time-expanded prompts when called directly

Consider updating the docstring to:

```python
prompt_vectors: (Tensor | None) Optional prompt vectors for multilingual models.
    Shape can be torch.Size([B, num_prompts]) or torch.Size([B, T_enc, num_prompts]) if already expanded.
```
please fix prompt_vectors shape in the docstring
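The shape discussion above can be visualized with a small sketch. Tensor names and the feature dimension are illustrative: the per-request [B, num_prompts] vectors are broadcast along the encoder time axis to [B, T_enc, num_prompts], which is presumably what a concat-style model appends to its encoder features.

```python
import torch

B, num_prompts, T_enc, feat_dim = 2, 3, 5, 8  # illustrative sizes
prompt_vectors = torch.eye(num_prompts)[torch.tensor([0, 2])]  # [B, num_prompts]

# Broadcast one prompt vector per encoder time step: [B, T_enc, num_prompts].
expanded = prompt_vectors.unsqueeze(1).expand(B, T_enc, num_prompts)

# Hypothetical concat-style model: prompt vectors are appended to the
# encoder features at every time step.
encoder_out = torch.randn(B, T_enc, feat_dim)
combined = torch.cat([encoder_out, expanded], dim=-1)  # [B, T_enc, feat_dim + num_prompts]
```

Note that `expand` creates a view without copying, so every time step shares the same underlying prompt vector.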
| """ | ||
| return input_time_steps // self.get_subsampling_factor() |
There was a problem hiding this comment.
The time step estimation uses integer division which may not accurately predict the actual encoder output time steps. Depending on the encoder architecture (e.g., with padding, different convolution parameters), the actual output time steps might differ from input_time_steps // subsampling_factor.
This could cause issues when the expanded prompt tensor has a different time dimension than the actual encoder expects. Consider:
- Using the actual encoder's time estimation logic if available
- Adding a safety buffer or validation
- Documenting this as an approximation that may need adjustment
Example:
# Add some buffer or use encoder's actual logic
return (input_time_steps + subsampling_factor - 1) // subsampling_factor| """ | |
| return input_time_steps // self.get_subsampling_factor() | |
| Note: This is an approximation. For most encoders, the output time steps are calculated as | |
| (input_time_steps + subsampling_factor - 1) // subsampling_factor to avoid underestimation. | |
| """ | |
| subsampling_factor = self.get_subsampling_factor() | |
| return (input_time_steps + subsampling_factor - 1) // subsampling_factor |
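The floor-vs-ceil distinction in this thread is easy to check in isolation. The standalone function below is a sketch, not the wrapper's actual method:

```python
def estimate_encoder_time_steps(input_time_steps: int, subsampling_factor: int) -> int:
    # Ceil division: never underestimates the output length when the input
    # is not an exact multiple of the subsampling factor.
    return (input_time_steps + subsampling_factor - 1) // subsampling_factor

# With a subsampling factor of 8, floor division loses a step on remainders:
assert 801 // 8 == 100                              # floor: underestimates
assert estimate_encoder_time_steps(801, 8) == 101   # ceil: one extra step
assert estimate_encoder_time_steps(800, 8) == 100   # exact multiple: equal
```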
| # Use "en-US" as the default prompt for zero encoding | ||
| # This region is sliced out before decoding, so language choice doesn't matter | ||
| default_prompt_idx = self._resolve_prompt_index("en-US") |
There was a problem hiding this comment.
The hardcoded "en-US" language code assumes this language will always be present in the prompt dictionary. If a prompt-enabled model doesn't include "en-US" in its prompt dictionary, this will cause a ValueError during initialization in init_zero_enc().
Consider either:
- Using the first available language from the prompt dictionary
- Making this configurable
- Adding validation at initialization to ensure "en-US" exists
Example fix:
# Get the first available language or use a configurable default
available_languages = list(self._prompt_config['prompt_dict'].keys())
default_lang = available_languages[0] if available_languages else "en-US"
default_prompt_idx = self._resolve_prompt_index(default_lang)| # Use "en-US" as the default prompt for zero encoding | |
| # This region is sliced out before decoding, so language choice doesn't matter | |
| default_prompt_idx = self._resolve_prompt_index("en-US") | |
| # Use the first available language as the default prompt for zero encoding | |
| # This region is sliced out before decoding, so language choice doesn't matter | |
| available_languages = list(self._prompt_config['prompt_dict'].keys()) | |
| default_lang = available_languages[0] if available_languages else "en-US" | |
| default_prompt_idx = self._resolve_prompt_index(default_lang) |
```python
indices = torch.tensor([s.prompt_idx for s in requests_states], device=self.device, dtype=torch.long)
# Validate indices
num_prompts = self._prompt_config['num_prompts']
if torch.any((indices < 0) | (indices >= num_prompts)):
    raise ValueError("Found out-of-range prompt index in batch.")
prompt_matrix = self._get_prompt_matrix()
prompt_vectors = prompt_matrix.index_select(0, indices)  # [B, num_prompts]
```

The logic for building prompt vectors (lines 400-408) is duplicated in `encode_processed_signals()` (lines 454-461). Consider extracting this into a helper method to reduce code duplication and improve maintainability.

Example:

```python
def _build_prompt_vectors(self, states: list) -> Tensor:
    """Build prompt vectors for a batch of states."""
    indices = torch.tensor([s.prompt_idx for s in states], device=self.device, dtype=torch.long)
    num_prompts = self._prompt_config['num_prompts']
    if torch.any((indices < 0) | (indices >= num_prompts)):
        raise ValueError("Found out-of-range prompt index in batch.")
    prompt_matrix = self._get_prompt_matrix()
    return prompt_matrix.index_select(0, indices)
```

Each call site then reduces to:

```python
prompt_vectors = self._build_prompt_vectors(requests_states)
```
```python
self._prompt_matrix_cache = {}
```

When `prompt_enabled` is True (line 180) but `_load_prompt_config()` returns an empty dict (line 212), `_prompt_config` will be empty. However, the code at lines 271-274 will attempt to use it, causing a RuntimeError at line 226.

This indicates a configuration issue, but it happens at runtime rather than initialization. Consider adding validation in `init_prompt_support()`:

```python
if self.prompt_enabled:
    self._prompt_config = self._load_prompt_config()
    if not self._prompt_config:
        raise RuntimeError(
            "Model has concat=True but prompt configuration (num_prompts, prompt_dictionary) "
            "is missing or invalid in model_defaults."
        )
self._prompt_matrix_cache = {}
```
```python
Args:
    processed_signal: (Tensor) processed signal. Shape is torch.Size([B, C, T]).
    processed_signal_length: (Tensor) processed signal length. Shape is torch.Size([B]).
    prompt_vectors: (Tensor | None) Optional prompt vectors for multilingual models. Shape is torch.Size([B, num_prompts]).
```

please fix prompt_vectors shape in the docstring
```python
Returns:
    (tuple[Tensor, Tensor]) encoder output and encoder output length.
"""
encoder_time_steps = self._estimate_encoder_time_steps(processed_signal.shape[2])
```

It looks like `_estimate_encoder_time_steps` is a single-line method. There's no need to keep it separate, since it contains only one line and isn't reused elsewhere.

There is duplicated code for building prompts in both `encode_raw_signals` and `encode_processed_signals`. It would be better to move this logic into a separate helper method.
artbataev left a comment:

Please fix an artifact from merging (see my comment in the code).
Everything else looks good, thank you!
```python
granularity = self._with_default(self.asr_output_granularity, default_asr_output_granularity)

return ASRRequestOptions(
    enable_itn=default_enable_itn if self.enable_itn is None else self.enable_itn,
```

@arushidNV looks like there are artifacts from merging main into your PR. You pass `enable_itn`, `enable_pnc`, and other arguments twice.

Should be fixed.
This also breaks tests:

```
SyntaxError: keyword argument repeated: enable_itn
```

Thank you for catching it! I have pushed the fix.
…VIDIA-NeMo#15036)

* Changes required for enabling prompt based models in Nemo Inference
* Apply isort and black reformatting
* Add prompt support to zero encoding to enable right padding
* Fixed MR Reviews
* Using one hot vector instead of caching
* Add prompt support to cache-aware pipeline
* Fixing failing CI tests

Signed-off-by: arushid <arushid@nvidia.com>
Signed-off-by: arushidNV <arushidNV@users.noreply.github.com>
Co-authored-by: arushidNV <arushidNV@users.noreply.github.com>
Signed-off-by: Akhil Varanasi <akhilvaranasi23@gmail.com>
Important

The Update branch button must only be pressed on very rare occasions. An outdated branch is never blocking the merge of a PR.
Please reach out to the automation team before pressing that button.
What does this PR do ?
Adds support for RNNT multilingual model with prompt input in Nemo Inference
Collection: ASR
Changelog
- Added `language_code` field to `ASRRequestOptions` for specifying the target language
- Added `prompt_idx` field and `set_prompt_index()` method for creating a prompt vector for each stream
- Added `prompt_vectors` parameter to the `encode()` method and `encode_with_prompts()` for prompt models
- Updated `encode_raw_signals()` and `encode_processed_signals()` to apply prompts

Usage
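Purely as an illustrative sketch of the changelog items above: only the `ASRRequestOptions` name and the new `language_code` field come from this PR; the dataclass shape, the other fields, and their defaults below are assumptions, not the actual NeMo definition.

```python
# Hypothetical stand-in for the request options class touched by this PR;
# only `ASRRequestOptions` and `language_code` are taken from the changelog.
from dataclasses import dataclass
from typing import Optional

@dataclass
class ASRRequestOptions:
    enable_itn: Optional[bool] = None
    enable_pnc: Optional[bool] = None
    language_code: Optional[str] = None  # new per-request field from this PR

# Each request can now select its target language; the pipeline resolves
# this code to a prompt index when the stream state is created.
options = ASRRequestOptions(language_code="de-DE")
print(options.language_code)  # → de-DE
```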
GitHub Actions CI
The Jenkins CI system has been replaced by GitHub Actions self-hosted runners.
The GitHub Actions CI will run automatically when the "Run CICD" label is added to the PR.
To re-run CI remove and add the label again.
To run CI on an untrusted fork, a NeMo user with write access must first click "Approve and run".
Before your PR is "Ready for review"
Pre checks:
PR Type:
If you haven't finished some of the above items you can still open "Draft" PR.
Who can review?
Anyone in the community is free to review the PR once the checks have passed.
Contributor guidelines contains specific people who can review PRs to various areas.
Additional Information