Skip to content
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
57e6c9f
Add LLaVA audio (sound) model support, FastConformer encoder, and CP-…
cuichenx Apr 21, 2026
8ce38ba
Address review comments: parameterize align_corners and FP8 padding r…
cuichenx Apr 21, 2026
933b330
remove HF model
cuichenx Apr 22, 2026
eb8b092
keep non-fastconformer HF model stubs and vision hf:// support
cuichenx Apr 22, 2026
4f4ec4c
revert unintended
cuichenx Apr 22, 2026
33a7e20
revert
cuichenx Apr 22, 2026
86858c3
revert extra state mute
cuichenx Apr 22, 2026
e8bdacb
revert fp8_pad_hook to use fp8_recipe-aware padding
cuichenx Apr 22, 2026
a22f4ba
revert unrelated TE import refactor in llava_model.py
cuichenx Apr 22, 2026
201eb75
add unit tests for multimodal CP helpers and RADIO state-dict hooks
cuichenx Apr 22, 2026
76cc7c5
address review: keep pg_collection/vp_stage and einops guard in RADIO
cuichenx May 1, 2026
17abb1d
restore ParakeetHuggingFaceModel using upstream HF FastConformer
cuichenx May 1, 2026
db7554e
fix(llava): guard sound_pad_to_clip_duration, document has_sounds sen…
cuichenx May 1, 2026
95a639f
fix(fastconformer): use model dtype + FE sampling rate, key NeMo cach…
cuichenx May 1, 2026
e664f1e
fix(cp): require patch_dim, assert pre-embedder hidden dim, clarify l…
cuichenx May 1, 2026
15757bf
fix(hf-module): tighten parakeet match to scheme + path-segment prefix
cuichenx May 1, 2026
70cdffc
test: cover fastconformer wrapper, LLaVA sound integration, CP helper…
cuichenx May 1, 2026
4d9a799
Merge branch 'main' into llava-model-audio
cuichenx May 1, 2026
e6d4da9
fix: address claude review feedback (nemo:// guard, redundant assignm…
cuichenx May 1, 2026
b373b26
fix: align unit-test logic with runtime contracts after CW-DFW dry run
cuichenx May 1, 2026
b6b1d03
chore: drop extra blank line after imports to satisfy CI isort
cuichenx May 1, 2026
6091e22
fix(llava): restore packed dynamic-res image path and audio-only guard
cuichenx May 2, 2026
03d8a57
fix(cp): split _split_num_frames in frame space; clean media-boundary…
cuichenx May 4, 2026
f75d38c
test: cover mixed-length video last-frame duplication in temporal gro…
cuichenx May 5, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 102 additions & 0 deletions megatron/core/models/huggingface/fastconformer_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
import torch

from megatron.core.models.huggingface import HuggingFaceModule

# NeMo model loading is slow, so cache the (preprocessor, encoder) tuple per
# `sound_model_type`. Keying by model id avoids returning a stale cached encoder
# when the same process constructs more than one Parakeet variant.
_NEMO_SOUND_MODEL_CACHE: dict[str, tuple] = {}


def get_nemo_sound_model(sound_model_type):
"""Load (and cache) a NeMo ASR encoder + preprocessor for the given ``nemo://`` model id."""
if sound_model_type not in _NEMO_SOUND_MODEL_CACHE:
import nemo.collections.asr as nemo_asr

asr_model = nemo_asr.models.ASRModel.from_pretrained(
model_name=sound_model_type.split("nemo://")[1]
)
# Avoid hangs from an unnecessary max-seq-len NCCL sync in some edge cases.
asr_model.encoder.sync_max_audio_length = False
for layer in asr_model.encoder.layers:
layer.self_attn.use_pytorch_sdpa = True
_NEMO_SOUND_MODEL_CACHE[sound_model_type] = (asr_model.preprocessor, asr_model.encoder)
return _NEMO_SOUND_MODEL_CACHE[sound_model_type]


class ParakeetHuggingFaceModel(HuggingFaceModule):
"""Wrapper for Parakeet sound encoders.

Supports two backends, selected by ``config.sound_model_type`` prefix:

- ``nemo://<model_name>`` loads a NeMo ASR encoder + preprocessor.
- ``hf://<model_name>`` loads the upstream Hugging Face FastConformer model
via ``transformers.AutoModel`` / ``AutoFeatureExtractor``.
"""

def __init__(self, config):
super().__init__(config)

self.use_nemo = config.sound_model_type.startswith("nemo://")
if self.use_nemo:
self.feature_extractor, self.model = get_nemo_sound_model(config.sound_model_type)

for module in self.model.modules():
if module.__class__.__name__.lower() == "dropout":
module.p = config.hidden_dropout

if config.recompute_granularity is not None:
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper,
)

self.model = checkpoint_wrapper(self.model)
elif config.sound_model_type.startswith("hf://"):
from transformers import AutoFeatureExtractor, AutoModel

sound_model_type = config.sound_model_type.split("hf://")[1]
self.feature_extractor = AutoFeatureExtractor.from_pretrained(sound_model_type)
self.model = AutoModel.from_pretrained(sound_model_type)

if config.recompute_granularity is not None:
self.model.gradient_checkpointing_enable()
else:
raise ValueError(f"Unknown sound model type: {config.sound_model_type}")

def _model_dtype(self) -> torch.dtype:
"""Return the dtype of the encoder's first parameter (defaults to bf16)."""
for param in self.model.parameters():
return param.dtype
return torch.bfloat16

def _sampling_rate(self) -> int:
"""Return the sampling rate the feature extractor expects (default 16 kHz)."""
return int(getattr(self.feature_extractor, "sampling_rate", 16000))

def forward(self, *args, **kwargs):
"""Forward pass returning (hidden_states, lengths).

Args:
args[0]: Sound clips tensor.
args[1]: Sound length tensor (used by NeMo backend; ignored for HF).
"""
if self.use_nemo:
features = self.feature_extractor(input_signal=args[0], length=args[1])
y = self.model(audio_signal=features[0], length=features[1])
# NeMo encoder returns [B, H, T]; LLaVA expects [B, T, H].
return y[0].permute(0, 2, 1), y[1]
else:
# HF feature extractor expects audio as the first arg only,
# not (audio, length) as in NeMo.
sound_clips = args[0]
features = self.feature_extractor(
sound_clips,
**kwargs,
return_tensors="pt",
sampling_rate=self._sampling_rate(),
return_attention_mask=True,
)
y = self.model(features.input_features.to(self._model_dtype()), features.attention_mask)
lengths = features.attention_mask.sum(dim=-1).to(y.last_hidden_state.device)
return y.last_hidden_state, lengths
25 changes: 25 additions & 0 deletions megatron/core/models/huggingface/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,27 @@ def get_hf_model_type(model_path):
"please install it with `pip install transformers`"
)

# Parakeet is a special case: its model id may be `nemo://...`, which
# AutoConfig cannot resolve, so detect it from the prefix. Require the
# `nemo://` or `hf://` scheme so unrelated local paths that happen to
# contain "parakeet" (e.g. a user directory) don't get misrouted.
lowered = model_path.lower()
if lowered.startswith(("nemo://", "hf://")):
model_id = lowered.split("://", 1)[1]
# Match a path segment whose name begins with "parakeet" (e.g.
# `nvidia/parakeet-tdt-0.6b-v2`). Substring-anywhere matches like
# `myparakeet-clone` are intentionally rejected.
if any(seg.startswith("parakeet") for seg in model_id.split("/")):
return "parakeet"
# Any other `nemo://` model can't be resolved by AutoConfig below;
# raise a clear error rather than letting `split("hf://")[1]` raise
# an IndexError with no context.
if lowered.startswith("nemo://"):
raise NotImplementedError(
f"nemo:// scheme is currently only supported for parakeet models, "
f"got {model_path}"
)

hf_config = AutoConfig.from_pretrained(model_path.split("hf://")[1])
Comment thread
cuichenx marked this conversation as resolved.
model_type = hf_config.architectures[0].lower()

Expand All @@ -91,6 +112,10 @@ def build_hf_model(config, model_path):
from megatron.core.models.huggingface.clip_model import SiglipHuggingFaceModel

model = SiglipHuggingFaceModel(config)
elif "parakeet" in model_type:
from megatron.core.models.huggingface.fastconformer_model import ParakeetHuggingFaceModel

model = ParakeetHuggingFaceModel(config)
else:
raise NotImplementedError(f"unsupported huggingface model {config.hf_config}")

Expand Down
Loading
Loading