Add LLaVA audio (sound) model support #4402
Merged
Changes from 22 commits
24 commits
57e6c9f Add LLaVA audio (sound) model support, FastConformer encoder, and CP-… (cuichenx)
8ce38ba Address review comments: parameterize align_corners and FP8 padding r… (cuichenx)
933b330 remove HF model (cuichenx)
eb8b092 keep non-fastconformer HF model stubs and vision hf:// support (cuichenx)
4f4ec4c revert unintended (cuichenx)
33a7e20 revert (cuichenx)
86858c3 revert extra state mute (cuichenx)
e8bdacb revert fp8_pad_hook to use fp8_recipe-aware padding (cuichenx)
a22f4ba revert unrelated TE import refactor in llava_model.py (cuichenx)
201eb75 add unit tests for multimodal CP helpers and RADIO state-dict hooks (cuichenx)
76cc7c5 address review: keep pg_collection/vp_stage and einops guard in RADIO (cuichenx)
17abb1d restore ParakeetHuggingFaceModel using upstream HF FastConformer (cuichenx)
db7554e fix(llava): guard sound_pad_to_clip_duration, document has_sounds sen… (cuichenx)
95a639f fix(fastconformer): use model dtype + FE sampling rate, key NeMo cach… (cuichenx)
e664f1e fix(cp): require patch_dim, assert pre-embedder hidden dim, clarify l… (cuichenx)
15757bf fix(hf-module): tighten parakeet match to scheme + path-segment prefix (cuichenx)
70cdffc test: cover fastconformer wrapper, LLaVA sound integration, CP helper… (cuichenx)
4d9a799 Merge branch 'main' into llava-model-audio (cuichenx)
e6d4da9 fix: address claude review feedback (nemo:// guard, redundant assignm… (cuichenx)
b373b26 fix: align unit-test logic with runtime contracts after CW-DFW dry run (cuichenx)
b6b1d03 chore: drop extra blank line after imports to satisfy CI isort (cuichenx)
6091e22 fix(llava): restore packed dynamic-res image path and audio-only guard (cuichenx)
03d8a57 fix(cp): split _split_num_frames in frame space; clean media-boundary… (cuichenx)
f75d38c test: cover mixed-length video last-frame duplication in temporal gro… (cuichenx)
102 changes: 102 additions & 0 deletions
megatron/core/models/huggingface/fastconformer_model.py
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,102 @@ | ||
| # Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. | ||
| import torch | ||
| | ||
| from megatron.core.models.huggingface import HuggingFaceModule | ||
| | ||
| # NeMo model loading is slow, so cache the (preprocessor, encoder) tuple per | ||
| # `sound_model_type`. Keying by model id avoids returning a stale cached encoder | ||
| # when the same process constructs more than one Parakeet variant. | ||
| _NEMO_SOUND_MODEL_CACHE: dict[str, tuple] = {} | ||
| | ||
| | ||
| def get_nemo_sound_model(sound_model_type): | ||
|     """Load (and cache) a NeMo ASR encoder + preprocessor for the given ``nemo://`` model id.""" | ||
|     if sound_model_type not in _NEMO_SOUND_MODEL_CACHE: | ||
|         import nemo.collections.asr as nemo_asr | ||
| | ||
|         asr_model = nemo_asr.models.ASRModel.from_pretrained( | ||
|             model_name=sound_model_type.split("nemo://")[1] | ||
|         ) | ||
|         # Avoid hangs from an unnecessary max-seq-len NCCL sync in some edge cases. | ||
|         asr_model.encoder.sync_max_audio_length = False | ||
|         for layer in asr_model.encoder.layers: | ||
|             layer.self_attn.use_pytorch_sdpa = True | ||
|         _NEMO_SOUND_MODEL_CACHE[sound_model_type] = (asr_model.preprocessor, asr_model.encoder) | ||
|     return _NEMO_SOUND_MODEL_CACHE[sound_model_type] | ||
| | ||
| | ||
| class ParakeetHuggingFaceModel(HuggingFaceModule): | ||
|     """Wrapper for Parakeet sound encoders. | ||
| | ||
|     Supports two backends, selected by ``config.sound_model_type`` prefix: | ||
| | ||
|     - ``nemo://<model_name>`` loads a NeMo ASR encoder + preprocessor. | ||
|     - ``hf://<model_name>`` loads the upstream Hugging Face FastConformer model | ||
|       via ``transformers.AutoModel`` / ``AutoFeatureExtractor``. | ||
|     """ | ||
| | ||
|     def __init__(self, config): | ||
|         super().__init__(config) | ||
| | ||
|         self.use_nemo = config.sound_model_type.startswith("nemo://") | ||
|         if self.use_nemo: | ||
|             self.feature_extractor, self.model = get_nemo_sound_model(config.sound_model_type) | ||
| | ||
|             for module in self.model.modules(): | ||
|                 if module.__class__.__name__.lower() == "dropout": | ||
|                     module.p = config.hidden_dropout | ||
| | ||
|             if config.recompute_granularity is not None: | ||
|                 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( | ||
|                     checkpoint_wrapper, | ||
|                 ) | ||
| | ||
|                 self.model = checkpoint_wrapper(self.model) | ||
|         elif config.sound_model_type.startswith("hf://"): | ||
|             from transformers import AutoFeatureExtractor, AutoModel | ||
| | ||
|             sound_model_type = config.sound_model_type.split("hf://")[1] | ||
|             self.feature_extractor = AutoFeatureExtractor.from_pretrained(sound_model_type) | ||
|             self.model = AutoModel.from_pretrained(sound_model_type) | ||
| | ||
|             if config.recompute_granularity is not None: | ||
|                 self.model.gradient_checkpointing_enable() | ||
|         else: | ||
|             raise ValueError(f"Unknown sound model type: {config.sound_model_type}") | ||
| | ||
|     def _model_dtype(self) -> torch.dtype: | ||
|         """Return the dtype of the encoder's first parameter (defaults to bf16).""" | ||
|         for param in self.model.parameters(): | ||
|             return param.dtype | ||
|         return torch.bfloat16 | ||
| | ||
|     def _sampling_rate(self) -> int: | ||
|         """Return the sampling rate the feature extractor expects (default 16 kHz).""" | ||
|         return int(getattr(self.feature_extractor, "sampling_rate", 16000)) | ||
| | ||
|     def forward(self, *args, **kwargs): | ||
|         """Forward pass returning (hidden_states, lengths). | ||
| | ||
|         Args: | ||
|             args[0]: Sound clips tensor. | ||
|             args[1]: Sound length tensor (used by NeMo backend; ignored for HF). | ||
|         """ | ||
|         if self.use_nemo: | ||
|             features = self.feature_extractor(input_signal=args[0], length=args[1]) | ||
|             y = self.model(audio_signal=features[0], length=features[1]) | ||
|             # NeMo encoder returns [B, H, T]; LLaVA expects [B, T, H]. | ||
|             return y[0].permute(0, 2, 1), y[1] | ||
|         else: | ||
|             # HF feature extractor expects audio as the first arg only, | ||
|             # not (audio, length) as in NeMo. | ||
|             sound_clips = args[0] | ||
|             features = self.feature_extractor( | ||
|                 sound_clips, | ||
|                 **kwargs, | ||
|                 return_tensors="pt", | ||
|                 sampling_rate=self._sampling_rate(), | ||
|                 return_attention_mask=True, | ||
|             ) | ||
|             y = self.model(features.input_features.to(self._model_dtype()), features.attention_mask) | ||
|             lengths = features.attention_mask.sum(dim=-1).to(y.last_hidden_state.device) | ||
|             return y.last_hidden_state, lengths | ||
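The backend selection and per-id caching in the diff above can be tried without NeMo or Transformers installed. The following is a minimal sketch, not the PR's code: `split_scheme`, `load_sound_model`, `_fake_slow_load`, `LOAD_CALLS`, and the model names used with them are hypothetical stand-ins, and the loader returns placeholder strings instead of real modules.

```python
# Hypothetical stand-ins for illustration only; none of these names are in the PR.
LOAD_CALLS: list[str] = []           # records every "slow" load that actually runs
_CACHE: dict[str, tuple[str, str]] = {}


def split_scheme(sound_model_type: str) -> tuple[str, str]:
    """Split 'nemo://name' or 'hf://name' into (scheme, name); reject anything else."""
    for scheme in ("nemo", "hf"):
        prefix = f"{scheme}://"
        if sound_model_type.startswith(prefix):
            return scheme, sound_model_type[len(prefix):]
    raise ValueError(f"Unknown sound model type: {sound_model_type}")


def _fake_slow_load(model_name: str) -> tuple[str, str]:
    """Placeholder for an expensive model load; returns (preprocessor, encoder)."""
    LOAD_CALLS.append(model_name)
    return (f"preprocessor:{model_name}", f"encoder:{model_name}")


def load_sound_model(sound_model_type: str) -> tuple[str, str]:
    """Load once per full model id, so two different ids never share a cache entry."""
    if sound_model_type not in _CACHE:
        _scheme, name = split_scheme(sound_model_type)
        _CACHE[sound_model_type] = _fake_slow_load(name)
    return _CACHE[sound_model_type]
```

Calling `load_sound_model` twice with the same id returns the same cached tuple, while a second id triggers its own load. That is the behaviour the comment above `_NEMO_SOUND_MODEL_CACHE` describes.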
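Both branches of `forward` return `(hidden_states, lengths)` with hidden states laid out as `[B, T, H]`. The sketch below mirrors the two shape operations from the diff, `permute(0, 2, 1)` and `attention_mask.sum(dim=-1)`, on plain nested lists so it runs without PyTorch. The helper names `permute_bht_to_bth` and `lengths_from_mask` are hypothetical and exist only for this illustration.

```python
def permute_bht_to_bth(x: list[list[list[float]]]) -> list[list[list[float]]]:
    """Reorder a nested-list tensor from [B, H, T] to [B, T, H].

    Pure-Python analogue of ``tensor.permute(0, 2, 1)``.
    """
    return [[list(frame) for frame in zip(*sample)] for sample in x]


def lengths_from_mask(attention_mask: list[list[int]]) -> list[int]:
    """Count valid (mask == 1) positions per sample.

    Pure-Python analogue of ``attention_mask.sum(dim=-1)``.
    """
    return [sum(row) for row in attention_mask]


# One sample with H=2 hidden channels and T=3 time steps.
batch_bht = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]]
batch_bth = permute_bht_to_bth(batch_bht)

# Two samples padded to 4 positions; the second has only 2 valid positions.
mask = [[1, 1, 1, 0], [1, 1, 0, 0]]
valid_lengths = lengths_from_mask(mask)
```

The sketch only counts mask entries. Whether those counts line up with the encoder's output time axis depends on the model's subsampling, which the diff does not show.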