Commit ac2795e

Authored by genquan9 (with co-authors tango4j, chtruong814, naymaraq)
Update Gemma3VL model training scripts (#15041)
* Fixing lines for multispeaker pipeline (#15030)
* Support gemma3vl tuning with verified performance
* Minor update of gemma3vl parameters for easier usage
* Inference optimization for cache-aware pipelines (#15035): optimize context manager and cache feature bufferer, speed up cache_feature_bufferer, improve docstring in BatchedCacheFeatureBufferer
* Fix loading of hybrid CTC-RNNT BPE models when using from_pretrained (#15042)
* Revert checkpoint scripts removal from #14617 (#15048)
* chore: remove ExportDeploy (#15033), including "add EP in PTQ" (#15015), removal of ExportDeploy tests and references
* Fix after ExportDeploy removal (#15051)
* Update changelog for v2.5.3 (#15055)
* [voice agent] Fix RTVI missing bot message (#15068)
* [voice agent] Make parakeet-eou model the default STT (#15069)
* Minor fixes to remove unused headers/lines and add an exception
* Resolve merge conflicts from GitHub
* Remove old buffered CTC script and references to speech_to_text_buffered_infer_ctc.py (#15061)
* Remove unused imports
* Remove NLP-related notebooks (#15070)
* chore: remove Automodel module (#15044)
* Add support for parallel checkpoint removal (#15073)
* Fix vlm engine changes in mcore (#15076); reverted (#15090, reverts commit b557cfd), reapplied, and reverted again
* Add docstring for encode_vqa_sample_multi_turns, and fix long comments
* Update MagpieTTS model with latest changes (#15031)
* ASR inference: expose RNN-T decoding params for context biasing (#15091)
* Update notebook (#15093)
* Fix lines with malformed anchor tags (#15095)
* Add copyright header for missing files

Contributors (signed-off and co-authored): genquan9, taejinp (tango4j), chtruong814 (Charlie Truong), naymaraq (Dav Karamyan), nithinraok (Nithin Rao), jenchen13 (Jenny Chen), Pablo Garay, stevehuang52 (He Huang), Dong Hyuk Chang (thomasdhc), dimapihtar (Dmytro Pykhtar), Jason (jasoli), Xuesong Yang, Roy Fejgin, Vladimir Bataev, meatybobby (Bob Chen).
1 parent 1b4d6dc commit ac2795e

File tree: 9 files changed (+423, -120 lines)

examples/voice_agent/server/server_configs/default.yaml (1 addition, 1 deletion)

@@ -51,4 +51,4 @@ tts:
   type: kokoro # choices in ['nemo', 'kokoro']
   model: "hexgrad/Kokoro-82M"
   model_config: "./server_configs/tts_configs/kokoro_82M.yaml"
-  device: "cuda"
+  device: "cuda"

nemo/collections/llm/gpt/model/gemma3.py (12 additions, 4 deletions)

@@ -21,6 +21,7 @@
 from pathlib import Path
 from typing import TYPE_CHECKING, Annotated, Callable, Optional, Tuple, Union

+import torch
 from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
 from megatron.core.inference.contexts import BaseInferenceContext
 from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding

@@ -32,7 +33,6 @@
 from megatron.core.transformer.mlp import MLP, MLPSubmodules
 from torch import Tensor, nn

-from nemo.collections.llm.fn.activation import openai_gelu
 from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel
 from nemo.collections.llm.gpt.model.gemma2 import TERowParallelLinearLayerNorm
 from nemo.collections.llm.utils import Config

@@ -146,16 +146,19 @@ class Gemma3Config(GPTConfig):
     attention_backend: AttnBackend = AttnBackend.flash

     # mlp
+    bias_activation_fusion: bool = True
     gated_linear_unit: bool = True
     add_bias_linear: bool = False
-    activation_func: Callable = openai_gelu
+    activation_func: Callable = torch.nn.functional.gelu

     # Do not change
     is_vision_language: bool = False
     flash_decode: bool = False
     gradient_accumulation_fusion: bool = False
     transformer_layer_spec: Union[ModuleSpec, Callable[["GPTConfig"], ModuleSpec]] = gemma3_layer_spec
     scatter_embedding_sequence_parallel: bool = True
+    apply_rope_fusion: bool = True
+    cross_entropy_fusion_impl: str = 'te'

     def configure_model(
         self,

@@ -338,7 +341,12 @@ def forward(self, max_seq_len: int, offset: int = 0, packed_seq: bool = False) -
         """Get global and local rope embedding"""
         rope_global = super().forward(max_seq_len, offset, packed_seq)
         rope_local = self.rope_local.forward(max_seq_len, offset, packed_seq)
-        return rope_local, rope_global
+        # When recompute_granularity is "full", save_for_backward is called to
+        # save all variables in a layer. It can only save tensors, not tuples,
+        # so stack rope_local and rope_global into a single tensor to avoid
+        # the error.
+        return torch.stack((rope_local, rope_global), dim=0)

@@ -372,7 +380,6 @@ def forward(
         inference_params: Optional[BaseInferenceContext] = None,
     ) -> Tuple[Tensor, Tensor]:
         """Switch to either local or global rope embedding before forward"""
-        assert isinstance(rotary_pos_emb, tuple)
         assert rotary_pos_cos is None and rotary_pos_sin is None

         if _is_local_attn_layer(self.layer_number, self.config.interleaved_attn_pattern):

@@ -614,6 +621,7 @@ def config(self):
             architectures=["Gemma3ForCausalLM"],
             num_hidden_layers=source.num_layers,
             hidden_size=source.hidden_size,
+            sliding_window=source.window_size,
             intermediate_size=source.ffn_hidden_size,
             num_attention_heads=source.num_attention_heads,
             head_dim=source.kv_channels,
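The rope change above works around an autograd limitation: with full recompute granularity, `save_for_backward` can store tensors but not tuples, so the two rope embeddings are stacked along a new leading dimension and recovered by indexing on the consumer side. A minimal sketch of that pack/unpack pattern, using NumPy's `stack` (which mirrors `torch.stack` semantics; shapes here are illustrative, not the model's real ones):

```python
import numpy as np

# Two rope embeddings with identical shapes (seq_len, dim) -- illustrative sizes.
rope_local = np.arange(6, dtype=np.float32).reshape(3, 2)
rope_global = rope_local * 10.0

# Pack: the equivalent of torch.stack((rope_local, rope_global), dim=0).
packed = np.stack((rope_local, rope_global), axis=0)
assert packed.shape == (2, 3, 2)  # a single array, safe for save_for_backward

# Unpack on the consumer side: index instead of tuple unpacking.
local_out, global_out = packed[0], packed[1]
assert np.array_equal(local_out, rope_local)
assert np.array_equal(global_out, rope_global)
```

This is also why the `assert isinstance(rotary_pos_emb, tuple)` check was removed in the attention forward: the embedding is now a single stacked tensor, not a tuple.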

nemo/collections/vlm/gemma3vl/data/mock.py (7 additions, 1 deletion)

@@ -186,13 +186,19 @@ def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
         generation_prompt_size = 5  # "ASSISTANT:" like
         prompt_end_idx = img_start_idx + IMAGE_TOKENS + generation_prompt_size
         labels[:prompt_end_idx] = -100
+        # Clip the labels in the mock data loader.
         labels = labels[1:]

+        # 5) prepare loss masks
+        # Calculate the loss mask from labels, to stay consistent with real
+        # data and avoid confusion.
+        loss_mask = torch.ones_like(labels, dtype=torch.float)
+        loss_mask[labels < 0] = 0.0
+
         return {
             "input_ids": input_ids,
             "position_ids": position_ids,
             "pixel_values": pixel_values,
-            "loss_mask": self.loss_mask,
+            "loss_mask": loss_mask,
             "labels": labels,
         }
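The mock loader now derives the loss mask directly from the labels, zeroing every position whose label is negative (the -100 ignore index used for masked-out prompt tokens). The same rule in plain Python (a sketch; the real code uses `torch.ones_like` and boolean indexing on tensors):

```python
IGNORE_INDEX = -100  # matches the -100 written into masked-out prompt positions

def loss_mask_from_labels(labels):
    """Return 1.0 where the label contributes to the loss, else 0.0."""
    return [0.0 if label < 0 else 1.0 for label in labels]

# Prompt tokens are masked to IGNORE_INDEX; only answer tokens receive loss.
labels = [IGNORE_INDEX, IGNORE_INDEX, 42, 17, IGNORE_INDEX]
assert loss_mask_from_labels(labels) == [0.0, 0.0, 1.0, 1.0, 0.0]
```

Deriving the mask from labels rather than storing a separate `self.loss_mask` keeps the two in sync by construction, which is the point of the fix.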

nemo/collections/vlm/gemma3vl/data/task_encoder.py (48 additions, 23 deletions)

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import json
 import logging
 from dataclasses import dataclass, field
 from typing import Optional

@@ -24,6 +25,7 @@
 from nemo.collections.vlm.data.task_encoder import TaskEncoder as BaseTaskEncoder
 from nemo.collections.vlm.data.task_encoder import TaskEncoderConfig as BaseTaskEncoderConfig
 from nemo.collections.vlm.data.utils import _find_pattern_indices
+from nemo.utils import logging


 @dataclass

@@ -101,58 +103,65 @@ def encode_batch(self, batch_data: DataBatch) -> dict:
         batch_data["media"] = batch_data["media"].reshape(-1, *batch_data["media"].shape[2:])
         return batch_data

-    def encode_vqa_sample(self, input_sample: VQASample) -> DataSample:
-        """Encode a VQA sample into a DataSample format.
+    def encode_vqa_sample_multi_turns(self, input_sample: VQASample):
+        """Encode a multi-turn VQA sample into a DataSample format.

         Args:
             input_sample (VQASample): Input VQA sample containing image, context and answers

         Returns:
-            DataSample: Encoded sample with processed image, tokens, labels and loss mask
+            Encoded tokens, labels and images.
         """
+        images = input_sample.image if isinstance(input_sample.image, list) else [input_sample.image]
+
+        contexts = json.loads(input_sample.context.decode('utf-8'))
         messages = []
         if self.config.system_prompt:
             messages.append({'role': 'system', 'content': self.config.system_prompt})
-
-        # Ensure context and answers are lists for consistent processing
-        contexts = input_sample.context if isinstance(input_sample.context, list) else [input_sample.context]
-        answers = input_sample.answers if isinstance(input_sample.answers, list) else [input_sample.answers]
-
-        # Build the conversation messages, replacing image placeholder
-        min_length = min(len(contexts), len(answers))
-        for i in range(min_length):
-            context_with_placeholder = contexts[i].replace("<image>", self.config.image_token)
-            messages.append({'role': self.config.roles[0], 'content': context_with_placeholder})
-            messages.append({'role': self.config.roles[1], 'content': answers[i]})
+        for context in contexts:
+            messages.append(context)

         # Apply chat template and process with HF processor
-        converted_messages = self.hf_processor.apply_chat_template(messages, tokenize=False)
+        # `add_generation_prompt=False` because we're providing the full ground-truth sequence.
+        # We remove the <bos> token using removeprefix('<bos>') since we're finetuning:
+        # the processor will add this token before training and the model expects only one.
+        converted_messages = self.hf_processor.apply_chat_template(
+            messages, add_generation_prompt=False, tokenize=False
+        ).removeprefix('<bos>')
         outputs = self.hf_processor(
-            images=input_sample.image,
+            images=images,
             text=converted_messages,
             return_tensors="pt",
             images_kwargs={"do_rescale": False},
         )
-
         # Get tokens and images from processor output
         # Squeeze the batch dimension as we process one sample at a time
         tokens = outputs["input_ids"].squeeze(0)
         images = outputs.get("pixel_values")  # Use .get() for optional images

         # --- Label Generation ---
+        # Same as: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/vlm/qwen2vl/data/task_encoder.py#L263-L270
         # Initialize labels with ignore placeholder
         labels = torch.full_like(tokens, self.config.ignore_place_holder)
-
         search_start_index = 0
-        for answer in answers:
+        for context in contexts:
+            if context['role'] != 'assistant':
+                continue
             # Tokenize the answer, including the stop string if provided
-            answer_with_stop = answer + (self.config.stop_string or "")
+            answer_with_stop = (
                context['content'][0]['text'].rstrip().lstrip() + "<end_of_turn>" + (self.config.stop_string or "")
+            )
+            answer_with_stop = answer_with_stop.rstrip().lstrip()
             answer_tokens = self.tokenizer.tokenizer(answer_with_stop, add_special_tokens=False)["input_ids"]
             answer_tokens_tensor = torch.tensor(answer_tokens, device=tokens.device)  # Ensure same device

+            # Sometimes the tokenizer can add an additional space. See:
+            # https://github.com/huggingface/transformers/issues/25073#issuecomment-1655271420
+            if self.tokenizer.tokenizer.decode(answer_tokens[0]) == "":
+                answer_tokens_tensor = answer_tokens_tensor[1:]
+
             # Find answer pattern in tokens
             answer_start, answer_end = _find_pattern_indices(tokens, answer_tokens_tensor, search_start_index)
-
             if answer_start >= 0:
                 labels[answer_start:answer_end] = tokens[answer_start:answer_end]
                 search_start_index = answer_end

@@ -170,11 +179,24 @@ def encode_vqa_sample(self, input_sample: VQASample) -> DataSample:
                     search_start_index,
                 )
                 break
+        return tokens, labels, images
+
+    def encode_vqa_sample(self, input_sample: VQASample) -> DataSample:
+        """Encode a VQA sample into a DataSample format.
+
+        Args:
+            input_sample (VQASample): Input VQA sample containing image, context and answers
+
+        Returns:
+            DataSample: Encoded sample with processed image, tokens, labels and loss mask
+        """
+        tokens, labels, images = self.encode_vqa_sample_multi_turns(input_sample)

         # Prepare final tensors
         tokens = tokens[:-1].contiguous()
         labels = labels[1:].contiguous()
         seqlen = len(tokens)  # Original sequence length before padding
+        position_ids = torch.arange(seqlen, dtype=torch.int64)

         # Pad tokens and labels to a multiple of `pad_to_multiple_of` if specified
         if self.config.pad_to_multiple_of:

@@ -191,7 +213,7 @@ def encode_vqa_sample(self, input_sample: VQASample) -> DataSample:

         # Compute loss mask
         loss_mask = torch.ones_like(labels, dtype=torch.float)
-        loss_mask[labels == self.config.ignore_place_holder] = 0.0
+        loss_mask[labels < 0] = 0.0

         # Convert images to bfloat16 and stack, or create an empty tensor if no images
         if images is not None and images.numel() > 0:

@@ -202,13 +224,16 @@ def encode_vqa_sample(self, input_sample: VQASample) -> DataSample:
             # Create an empty tensor with appropriate dimensions and dtype if no images
             processed_image = None

-        return Gemma3DataSample(
+        sample = Gemma3DataSample(
             __key__=input_sample.__key__,
             __restore_key__=input_sample.__restore_key__,
             __subflavor__=input_sample.__subflavor__,
             __subflavors__=input_sample.__subflavors__,
             pixel_values=processed_image,
             input_ids=tokens,
+            position_ids=position_ids,
             labels=labels,
             loss_mask=loss_mask,
         )
+
+        return sample
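Label generation in the encoder above hinges on locating each tokenized assistant answer inside the full token sequence (via `_find_pattern_indices`): only matched answer spans keep their token ids, everything else stays at the ignore placeholder. A self-contained sketch of that search-and-label loop in pure Python — `find_pattern` is a hypothetical stand-in for `_find_pattern_indices`, and the real code operates on torch tensors rather than lists:

```python
IGNORE = -100  # ignore placeholder for positions that carry no loss

def find_pattern(tokens, pattern, start):
    """Return (start, end) of the first occurrence of pattern at or after start, else (-1, -1)."""
    n, m = len(tokens), len(pattern)
    for i in range(start, n - m + 1):
        if tokens[i:i + m] == pattern:
            return i, i + m
    return -1, -1

def label_answers(tokens, answers):
    """Copy token ids into labels only over answer spans; leave the rest ignored."""
    labels = [IGNORE] * len(tokens)
    search_start = 0
    for answer in answers:
        s, e = find_pattern(tokens, answer, search_start)
        if s < 0:
            break  # answer not found (e.g. truncated sequence): stop labelling
        labels[s:e] = tokens[s:e]
        search_start = e  # search forward so repeated answers match in order
    return labels

# A user turn [5, 6, 7], assistant answer [8, 9], another user turn, answer [8, 9] again.
tokens = [5, 6, 7, 8, 9, 5, 6, 8, 9]
labels = label_answers(tokens, [[8, 9], [8, 9]])
assert labels == [IGNORE, IGNORE, IGNORE, 8, 9, IGNORE, IGNORE, 8, 9]
```

Advancing `search_start` past each match is what lets identical answers in different turns be labelled independently, mirroring the `search_start_index = answer_end` update in the diff.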

nemo/collections/vlm/gemma3vl/model/base.py (12 additions, 6 deletions)

@@ -37,6 +37,7 @@
 from nemo.collections.llm.gpt.model.gemma3 import Gemma3Config
 from nemo.collections.vlm.gemma3vl.model.vision import Gemma3VLMultimodalProjectorConfig, Gemma3VLVisionConfig
 from nemo.collections.vlm.neva.model.base import MODEL_CONFIG_ATTR, NevaModel, restore_model_weights
+from nemo.collections.vlm.qwen2vl.data.multimodal_tokens import IGNORE_INDEX
 from nemo.lightning import io
 from nemo.lightning.pytorch.optim import OptimizerModule
 from nemo.utils.import_utils import safe_import_from

@@ -78,6 +79,7 @@ def gemma3vl_data_step(dataloader_iter) -> Dict[str, torch.Tensor]:
         key: val.cuda(non_blocking=True) if key in required_keys and val is not None else None
         for key, val in _batch.items()
     }
+
     return _batch

@@ -392,18 +394,18 @@ def forward(
         input_ids = F.pad(input_ids, (0, padded_seq_len))
         position_ids = F.pad(position_ids, (0, padded_seq_len))
         if self.post_process:
-            labels = F.pad(labels, (0, padded_seq_len))
-            loss_mask = F.pad(loss_mask, (0, padded_seq_len))
+            labels = F.pad(labels, (0, padded_seq_len), value=IGNORE_INDEX)
+            loss_mask = F.pad(loss_mask, (0, padded_seq_len), value=0.0)

         # Compute language embedding
         if self.pre_process:
             safe_input_ids = input_ids
             # Replace image_token_id with 0 to avoid embedding index error
-            if self.image_token_id >= self.vocab_size:
-                image_token_mask = input_ids == self.image_token_id
-                safe_input_ids = input_ids.clone()
-                safe_input_ids[image_token_mask] = 0
+            image_token_mask = input_ids == self.image_token_id
+            safe_input_ids = input_ids.clone()
+            safe_input_ids[image_token_mask] = 0
             # (T, B, D)
+            # position_ids is None for qwen2 models, but is set for gemma3vl models.
             language_embedding = self.language_model.embedding(input_ids=safe_input_ids, position_ids=position_ids)
             # (B, T, D)
             language_embedding = language_embedding.transpose(1, 0).contiguous()

@@ -428,6 +430,7 @@
         combined_embedding = combined_embedding.transpose(1, 0).contiguous()

         # Run decoder model
+        # position_ids is None here for gemma3vl models, but is set for qwen2 models.
         output = self.language_model(
             input_ids=None,
             position_ids=None,

@@ -441,6 +444,8 @@

         if labels is None or loss_mask is None:
             return output
+
+        output = output.masked_fill(labels < 0, 0.0)
         return output, loss_mask

     def _preprocess_data(

@@ -536,6 +541,7 @@ def _process_sequence_parallel(
         combined_embedding = scatter_to_sequence_parallel_region(combined_embedding)
         return combined_embedding, labels, loss_mask, packed_seq_params

+    @torch.compile
     def _compute_attention_mask(
         self,
         input_ids: torch.Tensor,
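The forward-pass fix above pads labels with IGNORE_INDEX (and the loss mask with 0.0) instead of the default pad value of 0, so a padded position can never masquerade as valid label id 0. The effect, sketched with NumPy's `pad` as an analogue of `F.pad(..., value=IGNORE_INDEX)` (illustrative values, not the model's real tensors):

```python
import numpy as np

IGNORE_INDEX = -100
labels = np.array([12, 7, 31])          # three real label token ids
loss_mask = np.array([1.0, 1.0, 1.0])   # all three contribute to the loss

pad = 3  # pad the sequence up to a length multiple, as in the model's forward
padded_labels = np.pad(labels, (0, pad), constant_values=IGNORE_INDEX)
padded_mask = np.pad(loss_mask, (0, pad), constant_values=0.0)

assert padded_labels.tolist() == [12, 7, 31, -100, -100, -100]
assert padded_mask.tolist() == [1.0, 1.0, 1.0, 0.0, 0.0, 0.0]
# Downstream, output.masked_fill(labels < 0, 0.0) zeroes exactly those padded slots.
```

With the default pad value of 0, the padded label positions would have looked like real token id 0, which is why the explicit `value=` arguments matter.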

scripts/vlm/gemma3vl_export.py (new file, 83 additions)

# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Export Gemma3VL NeMo checkpoints to Hugging Face format."""

import argparse
from pathlib import Path

from huggingface_hub import hf_hub_download

from nemo.collections import llm


def main():
    parser = argparse.ArgumentParser(
        description="Export a NeMo vision language model checkpoint to Hugging Face format."
    )
    parser.add_argument(
        "--nemo_ckpt_path",
        type=str,
        required=True,
        help="Path to the NeMo checkpoint directory.",
    )
    parser.add_argument(
        "--output_hf_path",
        type=str,
        required=True,
        help="Path to save the converted Hugging Face checkpoint.",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        required=False,
        default=None,
        help="Name of the model on Hugging Face.",
    )

    args = parser.parse_args()

    llm.export_ckpt(
        path=Path(args.nemo_ckpt_path),
        target="hf",
        output_path=Path(args.output_hf_path),
        overwrite=True,
    )
    if args.model_name:
        # Copy auxiliary files from Hugging Face, if they exist, for Gemma3VL model export.
        copy_file_list = [
            "preprocessor_config.json",
            "chat_template.json",
            "config.json",
            "generation_config.json",
            "merges.txt",
            "tokenizer.json",
            "tokenizer_config.json",
            "vocab.json",
        ]
        for file_name in copy_file_list:
            try:
                downloaded_path = hf_hub_download(
                    repo_id=args.model_name,
                    filename=file_name,
                    local_dir=args.output_hf_path,
                )
                print(f"Downloaded {downloaded_path} while exporting gemma3vl model.")
            except Exception:
                print(f"Skipping {file_name} while exporting gemma3vl model.")


if __name__ == "__main__":
    main()
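The export script's copy loop is deliberately best-effort: auxiliary files missing from the Hub repo are skipped rather than failing the whole export. The pattern in isolation — `fetch` here is a hypothetical stand-in for `hf_hub_download`, which raises when a file is absent:

```python
def copy_optional_files(file_names, fetch):
    """Try to fetch each file; collect what succeeded, skip what failed."""
    copied, skipped = [], []
    for name in file_names:
        try:
            copied.append(fetch(name))
        except Exception:
            skipped.append(name)  # tolerate missing optional files
    return copied, skipped

def fake_fetch(name):
    """Stand-in for hf_hub_download: pretend merges.txt is absent from the repo."""
    if name == "merges.txt":
        raise FileNotFoundError(name)
    return f"/local/{name}"

copied, skipped = copy_optional_files(
    ["config.json", "merges.txt", "tokenizer.json"], fake_fetch
)
assert copied == ["/local/config.json", "/local/tokenizer.json"]
assert skipped == ["merges.txt"]
```

This matters because the file list mixes files every repo has (`config.json`) with format-specific ones (`merges.txt`, `vocab.json`) that a SentencePiece-based model like Gemma may not ship.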
