Gemma4 Assistant model support for MTP #2481
athitten started this conversation in Show and tell
NeMo-Automodel recently added support for the Gemma 4 assistant model (also called the drafter model), aimed at faster inference with Multi-Token Prediction (MTP). Assistant models are separate checkpoints (for example, google/gemma-4-E4B-it-assistant) that are much smaller (~70M parameters) than the base (target) model and can predict multiple tokens in the time the base model takes to predict one.
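To make the drafter/target split concrete, here is a minimal sketch of greedy speculative decoding with toy models. The names speculative_decode, greedy_decode and NextToken are hypothetical, and this is not NeMo-Automodel's or transformers' implementation: the drafter proposes k tokens, the target keeps the longest prefix that matches its own greedy choices and then adds one token of its own, so the output is identical to decoding with the target alone while needing fewer target passes.

```python
from typing import Callable, List, Tuple

# Toy "model" (hypothetical): maps a token prefix to its greedy next-token id.
NextToken = Callable[[List[int]], int]


def speculative_decode(target: NextToken, drafter: NextToken, prompt: List[int],
                       max_new_tokens: int, k: int = 4) -> Tuple[List[int], int]:
    """Greedy draft-then-verify loop. Returns (tokens, number of target passes)."""
    tokens = list(prompt)
    target_passes = 0
    while len(tokens) - len(prompt) < max_new_tokens:
        # 1) The small drafter proposes k tokens autoregressively (cheap).
        draft: List[int] = []
        for _ in range(k):
            draft.append(drafter(tokens + draft))
        # 2) The target verifies the drafted positions. A real system scores
        #    tokens + draft in ONE forward pass; we loop here only for clarity.
        target_passes += 1
        accepted: List[int] = []
        for tok in draft:
            expected = target(tokens + accepted)
            accepted.append(expected)  # equals tok when the draft is accepted
            if tok != expected:
                break  # first mismatch: keep the target's token and stop
        else:
            # All k drafts accepted: the same target pass yields one bonus token.
            accepted.append(target(tokens + accepted))
        tokens.extend(accepted)
    return tokens[: len(prompt) + max_new_tokens], target_passes


def greedy_decode(target: NextToken, prompt: List[int], max_new_tokens: int) -> List[int]:
    """Plain autoregressive decoding: one target pass per new token."""
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        tokens.append(target(tokens))
    return tokens
```

With a drafter that always agrees with the target and k=4, each target pass yields 5 tokens, so 20 new tokens take 4 target passes instead of 20; a drafter that is sometimes wrong still produces the same tokens, just with more passes.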
Specifically, NeMo-Automodel supports joint fine-tuning of the Gemma 4 base model and its drafter (assistant) model, so users can fine-tune both together on custom data and still benefit from speculative decoding when running inference with their fine-tuned model.
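One common way to express a joint objective is a weighted sum of the base model's next-token loss and the drafter's per-step losses. The sketch below assumes exactly that; the weighting, the helper names and the drafter_loss_weight parameter are hypothetical and may not match how the NeMo-Automodel recipes combine the two losses. It only computes the scalar objective; a real trainer would backpropagate it through both models.

```python
import math
from typing import Sequence


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """Numerically stable -log softmax(logits)[target]."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[target]


def mean_ce(logits_per_pos: Sequence[Sequence[float]], targets: Sequence[int]) -> float:
    """Average cross-entropy over positions."""
    losses = [cross_entropy(l, t) for l, t in zip(logits_per_pos, targets)]
    return sum(losses) / len(losses)


def joint_loss(base_logits: Sequence[Sequence[float]], base_targets: Sequence[int],
               drafter_logits: Sequence[Sequence[Sequence[float]]],
               drafter_targets: Sequence[Sequence[int]],
               drafter_loss_weight: float = 1.0) -> float:
    """Base loss plus a weighted drafter loss (HYPOTHETICAL weighting scheme).

    drafter_logits[k][i] are the drafter's logits at draft step k, position i;
    drafter_targets[k][i] is the label that step should predict.
    """
    base = mean_ce(base_logits, base_targets)
    per_step = [mean_ce(l, t) for l, t in zip(drafter_logits, drafter_targets)]
    drafter = sum(per_step) / len(per_step)  # average over the K draft steps
    return base + drafter_loss_weight * drafter
```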
In our experiments, inference after joint fine-tuning was ~1.98x faster with MTP than without it.
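The size of such a speed-up depends on how often the target accepts the drafter's tokens and on how cheap the drafter is. Below is a sketch of the standard idealized estimate from the speculative decoding literature, assuming each drafted token is accepted independently with probability alpha; alpha, k and cost_ratio are illustrative inputs, and this is not how the ~1.98x figure above was measured.

```python
def expected_tokens_per_pass(alpha: float, k: int) -> float:
    """Expected tokens emitted per target pass with k drafted tokens.

    Assumes i.i.d. acceptance with probability alpha, giving
    1 + alpha + ... + alpha**k = (1 - alpha**(k + 1)) / (1 - alpha).
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must be in [0, 1]")
    if alpha == 1.0:
        return float(k + 1)  # every draft accepted, plus the bonus token
    return (1.0 - alpha ** (k + 1)) / (1.0 - alpha)


def expected_speedup(alpha: float, k: int, cost_ratio: float) -> float:
    """Idealized walltime speed-up vs. plain decoding.

    cost_ratio = (time of one drafter step) / (time of one target pass);
    each iteration costs k drafter steps plus one target verification pass.
    """
    return expected_tokens_per_pass(alpha, k) / (k * cost_ratio + 1.0)
```

A small drafter (~70M parameters next to a multi-billion-parameter base) keeps cost_ratio low, so most of the gain comes from the acceptance rate, which is one reason to fine-tune the drafter on the same data as the base model.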
The drafter is co-trained end-to-end with the Gemma 4 base model via a composite model (Gemma4WithDrafter) that wires up shared K/V states, sqrt(H_b)-scaled embeddings, and a K-step recurrent forward pass, matching the Gemma 4 drafter tech report.

Joint Finetuning Recipes
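Before the recipes themselves, here is a toy, pure-Python illustration of the three drafter ingredients named above. Everything in it is an assumption for illustration: we read H_b as the base model's hidden size, seed the recurrence from the base model's last hidden state, and use a made-up tanh projection; draft_k_tokens and its weight matrices are hypothetical and this is not the Gemma4WithDrafter implementation.

```python
import math
from typing import List

Vec = List[float]
Mat = List[List[float]]


def dot(a: Vec, b: Vec) -> float:
    return sum(x * y for x, y in zip(a, b))


def matvec(m: Mat, v: Vec) -> Vec:
    return [dot(row, v) for row in m]


def softmax(xs: Vec) -> Vec:
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def attend(query: Vec, keys: List[Vec], values: List[Vec]) -> Vec:
    """Scaled dot-product attention of one query over the base model's cached K/V."""
    scale = 1.0 / math.sqrt(len(query))
    weights = softmax([dot(query, k) * scale for k in keys])
    return [sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(values[0]))]


def draft_k_tokens(last_token: int, base_hidden: Vec, base_keys: List[Vec],
                   base_values: List[Vec], embed: Mat, w_in: Mat, w_out: Mat,
                   k: int) -> List[int]:
    """Toy K-step recurrent drafter forward (HYPOTHETICAL shapes and weights).

    embed: vocab x H_b token embeddings
    w_in:  H_b x 2*H_b projection of [scaled embedding ; recurrent state]
    w_out: vocab x H_b output projection
    """
    h_b = len(base_hidden)
    state = list(base_hidden)  # assumption: recurrence starts from the base hidden state
    token = last_token
    drafts: List[int] = []
    for _ in range(k):
        # Gemma-style input scaling: multiply the embedding by sqrt(H_b).
        emb = [x * math.sqrt(h_b) for x in embed[token]]
        # Fuse the embedding with the previous step's state (the recurrent part).
        state = [math.tanh(x) for x in matvec(w_in, emb + state)]
        # Read context from the base model's K/V instead of a separate drafter cache.
        context = attend(state, base_keys, base_values)
        state = [s + c for s, c in zip(state, context)]
        logits = matvec(w_out, state)
        token = logits.index(max(logits))  # greedy pick; fed into the next step
        drafts.append(token)
    return drafts
```

The point of sharing K/V is that the drafter needs no cache of its own: it reuses the context the base model has already computed, which helps keep each drafter step cheap.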
We provide two reference configs for jointly fine-tuning gemma-4-E4B-it and gemma-4-E4B-it-assistant: gemma4_4b_joint_drafter_medpix.yaml, which uses the MedPix-VQA dataset, and gemma4_4b_joint_drafter_tulu_magicoder_mix.yaml, which uses a text-only Tulu-3 + Magicoder mix that is larger than MedPix-VQA. We also provide an inference benchmark script, benchmark_mtp_inference.py, that validates speculative-decoding throughput on the saved checkpoint.
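A throughput comparison of this kind boils down to timing the same generation with and without the drafter. The helper below is a hypothetical sketch, not the logic of benchmark_mtp_inference.py: each generate callable is assumed to run one fixed prompt set (for example a transformers generate call with or without the drafter) and return the number of new tokens, and medians are used so one slow run does not skew the ratio.

```python
import time
from statistics import median
from typing import Callable, List


def measure_throughput(generate: Callable[[], int], repeats: int = 5,
                       warmup: int = 1) -> List[float]:
    """Call `generate` (returns #new tokens) and return tokens/sec for each timed run."""
    for _ in range(warmup):
        generate()  # untimed: excludes one-off costs such as lazy initialization
    rates: List[float] = []
    for _ in range(repeats):
        start = time.perf_counter()
        num_tokens = generate()
        elapsed = time.perf_counter() - start
        rates.append(num_tokens / elapsed)
    return rates


def speedup(baseline_rates: List[float], mtp_rates: List[float]) -> float:
    """Median MTP throughput divided by median baseline throughput."""
    return median(mtp_rates) / median(baseline_rates)
```

With greedy decoding both runs should emit the same tokens, so the ratio reflects speed alone rather than differences in output length.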
Results
examples/vlm_finetune/gemma4_joint_drafter/gemma4_4b_joint_drafter_medpix.yaml
Loss curve: [figure]
examples/vlm_finetune/gemma4_joint_drafter/gemma4_4b_joint_drafter_tulu_magicoder_mix.yaml
Loss curve: [figure]
benchmark_mtp_inference.py (which uses transformers generate), run against the saved base/ + drafter/ pair. Results below:
[Benchmark results figure]

MTP support for other models in NeMo-Automodel
The Gemma 4 assistant is a separate model, for which NeMo-Automodel provides scaffolding to fine-tune it jointly with the base model. We also support MTP for models that have built-in MTP layers, such as Nemotron V3, DeepSeek V4 (Flash), and Qwen3.6 (dense and MoE). See the respective model recipes to enable it.
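For models with built-in MTP layers, a common formulation (for example DeepSeek-V3's) trains the MTP module at depth d to predict the token d + 1 positions ahead, alongside the ordinary next-token head. The helper below is a hypothetical sketch of how the labels shift per depth under that assumption; the individual NeMo-Automodel recipes may construct their targets differently. Masked positions use -100, the default ignore_index of PyTorch's cross-entropy loss.

```python
from typing import List

IGNORE_INDEX = -100  # default ignore_index of PyTorch's cross-entropy loss


def mtp_labels(token_ids: List[int], depth: int, ignore_index: int = IGNORE_INDEX) -> List[int]:
    """Labels for MTP depth `depth` (depth 0 = the ordinary next-token head).

    Position i at depth d is trained to predict token i + d + 1; positions whose
    target would run past the end of the sequence are masked with ignore_index.
    """
    shift = depth + 1
    n = len(token_ids)
    return [token_ids[i + shift] if i + shift < n else ignore_index for i in range(n)]
```

For the sequence [5, 6, 7, 8], depth 0 gives [6, 7, 8, -100] and depth 1 gives [7, 8, -100, -100]: each extra depth sees one fewer valid target per sequence.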
Thanks to @adil-a, @khazic, and @HuiyingLi for this!