Merged

26 commits
1b0918f
[WIP] Initial Llama4 implementation
ebsmothers Apr 7, 2025
2717a49
small changes
ebsmothers Apr 8, 2025
e6328e9
Merge branch 'main' into initial-llama4
ebsmothers Apr 8, 2025
d9b207e
sigmoid only for moe
ebsmothers Apr 8, 2025
2022dff
generation works
ebsmothers Apr 10, 2025
f8e3a73
philip's changes
ebsmothers Apr 10, 2025
0bce983
minor cleanup
ebsmothers Apr 10, 2025
f918aa3
fix weight update, copy-paste TP API to minimize nightly dep
ebsmothers Apr 10, 2025
32ebae3
salman's comments
ebsmothers Apr 10, 2025
8b14989
small rope and rms refactors, couple more comments
ebsmothers Apr 10, 2025
31b9de0
fix key name in tune_to_hf
ebsmothers Apr 10, 2025
b718e22
Use the correct mapping for tune_to_hf
joecummings Apr 10, 2025
df18d23
Rework Scout training config
joecummings Apr 10, 2025
b2b03ad
Add back in profiler IG
joecummings Apr 10, 2025
2f57265
Comments
joecummings Apr 10, 2025
b9d7866
Comments
joecummings Apr 10, 2025
cdd9e83
Rework saving logic to HF
joecummings Apr 10, 2025
8c79661
Make sure the default case is handled in conversion
joecummings Apr 10, 2025
c938fa9
Update Maverick and Scout generation configs
joecummings Apr 10, 2025
aa1dcd7
Updated encoder from 3.2 to 4
pbontrager Apr 10, 2025
15777a4
rafi's comments
ebsmothers Apr 10, 2025
8574c71
fixed unit tests
pbontrager Apr 10, 2025
97025a6
readme updates
ebsmothers Apr 10, 2025
4986756
fix empy encoder inputs error
ebsmothers Apr 10, 2025
a7e0596
remove unnecessary config reads from checkpointer
ebsmothers Apr 10, 2025
70f01b6
typo
ebsmothers Apr 11, 2025
2 changes: 2 additions & 0 deletions README.md
@@ -10,6 +10,7 @@
[**Overview**](#overview-) | [**Installation**](#installation-%EF%B8%8F) | [**Get Started**](#get-started-) | [**Documentation**](https://pytorch.org/torchtune/main/index.html) | [**Community**](#community-) | [**Citing torchtune**](#citing-torchtune-) | [**License**](#license)

### 📣 Recent updates 📣
* *April 2025*: Llama4 is now available in torchtune! Try out our full finetuning configs [here](recipes/configs/llama4) (LoRA coming soon!)
* *February 2025*: Multi-node training is officially [open for business in torchtune](https://pytorch.org/torchtune/main/tutorials/multinode.html)! Full finetune on multiple nodes to take advantage of larger batch sizes and models.
* *December 2024*: torchtune now supports **Llama 3.3 70B**! Try it out by following our installation instructions [here](#installation-%EF%B8%8F), then run any of the configs [here](recipes/configs/llama3_3).
* *November 2024*: torchtune has released [v0.4.0](https://github.com/pytorch/torchtune/releases/tag/v0.4.0) which includes stable support for exciting features like activation offloading and multimodal QLoRA
@@ -90,6 +91,7 @@ For the above recipes, torchtune supports many state-of-the-art models available

| Model | Sizes |
|-----------------------------------------------|-----------|
| [Llama4](https://www.llama.com/docs/model-cards-and-prompt-formats/llama4) | Scout (17B x 16E) [[models](torchtune/models/llama4/_model_builders.py), [configs](recipes/configs/llama4/)] |
| [Llama3.3](https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_3) | 70B [[models](torchtune/models/llama3_3/_model_builders.py), [configs](recipes/configs/llama3_3/)] |
| [Llama3.2-Vision](https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2#-llama-3.2-vision-models-(11b/90b)-) | 11B, 90B [[models](torchtune/models/llama3_2_vision/_model_builders.py), [configs](recipes/configs/llama3_2_vision/)] |
| [Llama3.2](https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2) | 1B, 3B [[models](torchtune/models/llama3_2/_model_builders.py), [configs](recipes/configs/llama3_2/)] |
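The end-to-end flow behind the new README entry is spelled out in the Scout config header later in this diff; a minimal single-node sketch, using only commands that appear elsewhere in this PR, would be:

    # Download the Scout checkpoint (requires approved Hugging Face access, see the docs change below)
    tune download meta-llama/Llama-4-Scout-17B-16E-Instruct --hf-token <HF_TOKEN>

    # Launch a full finetune on 8 GPUs with the config added in this PR
    tune run --nproc_per_node 8 full_finetune_distributed --config llama4/scout_17B_16E_full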
33 changes: 33 additions & 0 deletions docs/source/api_ref_models.rst
@@ -6,6 +6,39 @@ torchtune.models

.. currentmodule:: torchtune.models

llama4
------

Multimodal models from the Llama4 family that support text and image input.

Important: You need to request access on Hugging Face before downloading the model weights.

To download the Llama-4-Scout-17B-16E-Instruct model:

.. code-block:: bash

    tune download meta-llama/Llama-4-Scout-17B-16E-Instruct --hf-token <HF_TOKEN>

To download the Llama-4-Maverick-17B-128E-Instruct model:

.. code-block:: bash

    tune download meta-llama/Llama-4-Maverick-17B-128E-Instruct --hf-token <HF_TOKEN>

.. autosummary::
    :toctree: generated/
    :nosignatures:

    llama4.llama4_scout_17b_16e
    llama4.llama4_maverick_17b_128e
    llama4.llama4_vision_encoder
    llama4.llama4_vision_projection_head
    llama4.llama4_decoder
    llama4.Llama4VisionEncoder
    llama4.Llama4VisionProjectionHead
    llama4.Llama4Tokenizer
    llama4.Llama4Transform

llama3.3
--------

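The configs in this PR expect checkpoints under /tmp/Llama-4-Scout-17B-16E-Instruct and /tmp/Llama-4-Maverick-17B-128E-Instruct. If you want to be explicit about where tune download puts the files, passing --output-dir should work; that flag is standard tune download usage rather than something introduced in this diff:

    # Assumption: --output-dir is the usual way to pin the download location
    tune download meta-llama/Llama-4-Scout-17B-16E-Instruct --output-dir /tmp/Llama-4-Scout-17B-16E-Instruct --hf-token <HF_TOKEN>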
86 changes: 86 additions & 0 deletions recipes/configs/llama4/maverick_17B_128E_full.yaml
@@ -0,0 +1,86 @@
# Config for multi-device full finetuning in full_finetune_distributed.py
# using a Llama4 17Bx128E MoE model
#
# This config assumes that you've run the following command before launching:
# tune download meta-llama/Llama-4-Maverick-17B-128E-Instruct
#
# Full finetuning of Llama4 17Bx128E is only possible in a multi-node setting.
# An example slurm script can be found under recipes/full_finetune_multinode.slurm.
# Example usage:
#   sbatch full_finetune_multinode.slurm
#
# This config has only been tested on two 8xH100 nodes.

output_dir: /tmp/torchtune/llama4_17Bx128E/full

# Model arguments
model:
  _component_: torchtune.models.llama4.llama4_maverick_17b_128e

tensor_parallel_dim: 8
tensor_parallel_plan:
  _component_: torchtune.models.llama4.decoder_only_tp_plan
data_parallel_shard_dim: -1 # Will infer based on TP dim, effectively controls FSDP
data_parallel_replicate_dim: 1

tokenizer:
  _component_: torchtune.models.llama4.llama4_transform
  path: ${model_dir}/tokenizer.model
  max_seq_len: null
  max_num_tiles: 16

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Llama-4-Maverick-17B-128E-Instruct
  checkpoint_files:
    filename_format: model-{}-of-{}.safetensors
    max_filename: "00055"
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA4
resume_from_checkpoint: False

# Dataset
dataset:
  _component_: torchtune.datasets.alpaca_dataset
  packed: False
seed: null
shuffle: True

# Training arguments
epochs: 1
max_steps_per_epoch: null
batch_size: 1
gradient_accumulation_steps: 1 # Use to increase effective batch size
optimizer:
  _component_: torch.optim.AdamW
  lr: 2e-5
  fused: False
optimizer_in_bwd: False
loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
clip_grad_norm: null

# cuda, cpu, rocm, xpu...
device: cuda

# Memory management / performance
enable_activation_checkpointing: True
enable_activation_offloading: False
fsdp_cpu_offload: True
compile: False # torch.compile, set to true for perf/memory improvement

# Reduced precision
dtype: bf16

# Log metrics during training
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True

# Useful for understanding how to optimize memory and performance
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: False
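Because Maverick only fits across multiple nodes, the launch path differs from Scout's single-node command. Stitching together the two commands referenced in this config's header gives roughly the following; you will likely need to point the slurm script at llama4/maverick_17B_128E_full yourself, since its contents are not part of this diff:

    tune download meta-llama/Llama-4-Maverick-17B-128E-Instruct --hf-token <HF_TOKEN>
    # Edit recipes/full_finetune_multinode.slurm to reference this config, then:
    sbatch full_finetune_multinode.slurm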
87 changes: 87 additions & 0 deletions recipes/configs/llama4/scout_17B_16E_full.yaml
@@ -0,0 +1,87 @@
# Config for multi-device full finetuning in full_finetune_distributed.py
# using a Llama4 17Bx16E MoE model
#
# This config assumes that you've run the following command before launching:
# tune download meta-llama/Llama-4-Scout-17B-16E-Instruct
#
# To launch on 8 devices, run the following command from root:
# tune run --nproc_per_node 8 full_finetune_distributed --config llama4/scout_17B_16E_full
#
# You can add specific overrides through the command line. For example, to use a larger bsz:
# tune run --nproc_per_node 8 full_finetune_distributed --config llama4/scout_17B_16E_full batch_size=8
#
# This config has only been tested on 8xA100 and 16xH100 machines.

output_dir: /tmp/torchtune/llama4_17Bx16E/full

# Modeling arguments
model:
  _component_: torchtune.models.llama4.llama4_scout_17b_16e

tensor_parallel_dim: 2 # For multi-node training we recommend tensor_parallel_dim: 8
tensor_parallel_plan:
  _component_: torchtune.models.llama4.decoder_only_tp_plan
data_parallel_shard_dim: -1 # Will infer based on TP dim, effectively controls FSDP
data_parallel_replicate_dim: 1

tokenizer:
  _component_: torchtune.models.llama4.llama4_transform
  path: /tmp/Llama-4-Scout-17B-16E-Instruct/tokenizer.model
  max_seq_len: null
  max_num_tiles: 16

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Llama-4-Scout-17B-16E-Instruct
  checkpoint_files:
    filename_format: model-{}-of-{}.safetensors
    max_filename: "00050"
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA4
resume_from_checkpoint: False

# Dataset
dataset:
  _component_: torchtune.datasets.alpaca_dataset
  packed: False
seed: null
shuffle: True

# Training arguments
epochs: 1
max_steps_per_epoch: null
batch_size: 1
gradient_accumulation_steps: 1 # Use to increase effective batch size
optimizer:
  _component_: torch.optim.AdamW
  lr: 2e-5
  fused: False
optimizer_in_bwd: False
loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
clip_grad_norm: null

# cuda, cpu, rocm, xpu...
device: cuda

# Memory management / performance
enable_activation_checkpointing: True
enable_activation_offloading: False
fsdp_cpu_offload: True
compile: False # torch.compile, set to true for perf/memory improvement

# Reduced precision
dtype: bf16

# Log metrics during training
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True

# Useful for understanding how to optimize memory and performance
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: False
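The header of this config already shows a batch_size override; the same mechanism handles nested keys via dotted paths (standard torchtune config overrides, not specific to this PR). For example, raising the tensor-parallel degree to the multi-node recommendation noted above and lowering the learning rate, purely as an illustration:

    # Illustrative values; tensor_parallel_dim=8 matches the comment in this config
    tune run --nproc_per_node 8 full_finetune_distributed --config llama4/scout_17B_16E_full \
        tensor_parallel_dim=8 optimizer.lr=1e-5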
48 changes: 48 additions & 0 deletions recipes/configs/llama4/scout_17B_16E_generation_distributed.yaml
@@ -0,0 +1,48 @@
# Config for running the InferenceRecipe in dev/generate_v2.py to generate output
# from a Llama4 17Bx16E MoE model
#
# This config assumes that you've run the following command before launching:
# tune download meta-llama/Llama-4-Scout-17B-16E-Instruct
#
# To launch, run the following command:
# tune run --nproc_per_node 4 dev/generate_v2_distributed --config llama4/scout_17B_16E_generation_distributed

# Model arguments
model:
  _component_: torchtune.models.llama4.llama4_scout_17b_16e

tensor_parallel_plan:
  _component_: torchtune.models.llama4.decoder_only_tp_plan

tokenizer:
  _component_: torchtune.models.llama4.llama4_transform
  path: /tmp/Llama-4-Scout-17B-16E-Instruct/tokenizer.model
  max_seq_len: null
  max_num_tiles: 16

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Llama-4-Scout-17B-16E-Instruct # You can also point this to your finetuned model!
  checkpoint_files:
    filename_format: model-{}-of-{}.safetensors
    max_filename: "00050"
  output_dir: ./ # No need for an output dir
  model_type: LLAMA4

use_distributed_state_dict: True
use_flex: True # Use PyTorch's FlexAttention for construction of attention masks

# Environment
device: cuda
dtype: bf16
seed: 1234
log_level: INFO

# Generation arguments
prompt:
  system: You are a helpful assistant.
  user:
    text: How are you doing?
max_new_tokens: 200
temperature: 0.6 # 0.8 and 0.6 are popular values to try
top_k: 300
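To try a different question without editing the YAML, the prompt defined above can be overridden from the command line with dotted paths, the same standard override mechanism the training configs use; the prompt text here is only an illustration:

    # Illustrative prompt override via the standard dotted-path syntax
    tune run --nproc_per_node 4 dev/generate_v2_distributed --config llama4/scout_17B_16E_generation_distributed \
        prompt.user.text="What can you tell me about Llama4?"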
13 changes: 8 additions & 5 deletions recipes/dev/generate_v2.py
@@ -93,7 +93,10 @@ def setup(self, cfg: DictConfig) -> None:
model = config.instantiate(cfg.model)
model.load_state_dict(_ckpt_dict[training.MODEL_KEY])
self.model = model
self._logger.info(f"Model was initialized with precision {self._dtype}.")
self.model.eval()
self._logger.info(
f"Model was initialized with precision {self._dtype} and put into eval mode."
)

# Instantiate transforms
self.model_transform = config.instantiate(cfg.tokenizer)
@@ -128,7 +131,7 @@ def generate(self, cfg: DictConfig):
"""The main entry point for generating tokens from a prompt."""
# 1. Convert input to messages
messages = self.to_messages(OmegaConf.to_container(cfg.prompt))
- is_multimodal_input = any([m.contains_media for m in messages])
+ is_image_input = any([m.contains_media for m in messages])

# 2. Apply model transform
model_inputs = self.model_transform({"messages": messages}, inference=True)
@@ -141,7 +144,7 @@
batch_size=1,
dtype=self._dtype,
encoder_max_seq_len=(
- self.model_transform.image_seq_len if is_multimodal_input else None
+ self.model_transform.image_seq_len if is_image_input else None
),
decoder_max_seq_len=total_response_length,
)
@@ -158,7 +161,7 @@

# 5. Collate to batch size of 1 and tensor-ify
batch = {}
- if is_multimodal_input:
+ if is_image_input:
batch = padded_collate_tiled_images_and_mask(
[model_inputs],
pad_direction="left",
@@ -182,7 +185,7 @@
token = sample(logits, temperature=cfg.temperature, top_k=cfg.top_k)
generated_tokens.append(token.item())

- if is_multimodal_input:
+ if is_image_input:
# Don't need image info b/c we only support 1 image and it's been
# processed by the model now
batch.pop("encoder_input")