🦙🦙🦙🦙 Llama4 in torchtune 🦙🦙🦙🦙 #2570
Merged

Commits (26):
- 1b0918f: [WIP] Initial Llama4 implementation (ebsmothers)
- 2717a49: small changes (ebsmothers)
- e6328e9: Merge branch 'main' into initial-llama4 (ebsmothers)
- d9b207e: sigmoid only for moe (ebsmothers)
- 2022dff: generation works (ebsmothers)
- f8e3a73: philip's changes (ebsmothers)
- 0bce983: minor cleanup (ebsmothers)
- f918aa3: fix weight update, copy-paste TP API to minimize nightly dep (ebsmothers)
- 32ebae3: salman's comments (ebsmothers)
- 8b14989: small rope and rms refactors, couple more comments (ebsmothers)
- 31b9de0: fix key name in tune_to_hf (ebsmothers)
- b718e22: Use the correct mapping for tune_to_hf (joecummings)
- df18d23: Rework Scout training config (joecummings)
- b2b03ad: Add back in profiler IG (joecummings)
- 2f57265: Comments (joecummings)
- b9d7866: Comments (joecummings)
- cdd9e83: Rework saving logic to HF (joecummings)
- 8c79661: Make sure the default case is handled in conversion (joecummings)
- c938fa9: Update Maverick and Scout generation configs (joecummings)
- aa1dcd7: Updated encoder from 3.2 to 4 (pbontrager)
- 15777a4: rafi's comments (ebsmothers)
- 8574c71: fixed unit tests (pbontrager)
- 97025a6: readme updates (ebsmothers)
- 4986756: fix empy encoder inputs error (ebsmothers)
- a7e0596: remove unnecessary config reads from checkpointer (ebsmothers)
- 70f01b6: typo (ebsmothers)
recipes/configs/llama4/maverick_17B_128E_full.yaml (new file, 86 additions)

```yaml
# Config for multi-device full finetuning in full_finetune_distributed.py
# using a Llama4 17Bx128E MoE model
#
# This config assumes that you've run the following command before launching:
#   tune download meta-llama/Llama-4-Maverick-17B-128E-Instruct
#
# Full finetuning of Llama4 17Bx128E is only possible in a multi-node setting.
# An example slurm script can be found under recipes/full_finetune_multinode.slurm.
# Example usage:
#   sbatch full_finetune_multinode.slurm
#
# This config has only been tested on two 8xH100 nodes.

output_dir: /tmp/torchtune/llama4_17Bx128E/full
model_dir: /tmp/Llama-4-Maverick-17B-128E-Instruct  # Assumed definition; ${model_dir} is referenced below but its declaration was lost in extraction

# Model arguments
model:
  _component_: torchtune.models.llama4.llama4_maverick_17b_128e

tensor_parallel_dim: 8
tensor_parallel_plan:
  _component_: torchtune.models.llama4.decoder_only_tp_plan
data_parallel_shard_dim: -1 # Will infer based on TP dim, effectively controls FSDP
data_parallel_replicate_dim: 1

tokenizer:
  _component_: torchtune.models.llama4.llama4_transform
  path: ${model_dir}/tokenizer.model
  max_seq_len: null
  max_num_tiles: 16

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Llama-4-Maverick-17B-128E-Instruct
  checkpoint_files:
    filename_format: model-{}-of-{}.safetensors
    max_filename: "00055"
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA4
resume_from_checkpoint: False

# Dataset
dataset:
  _component_: torchtune.datasets.alpaca_dataset
  packed: False
seed: null
shuffle: True

# Training arguments
epochs: 1
max_steps_per_epoch: null
batch_size: 1
gradient_accumulation_steps: 1 # Use to increase effective batch size
optimizer:
  _component_: torch.optim.AdamW
  lr: 2e-5
  fused: False
optimizer_in_bwd: False
loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
clip_grad_norm: null

# cuda, cpu, rocm, xpu...
device: cuda

# Memory management / performance
enable_activation_checkpointing: True
enable_activation_offloading: False
fsdp_cpu_offload: True
compile: False # torch.compile, set to True for perf/memory improvement

# Reduced precision
dtype: bf16

# Log metrics during training
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True

# Useful for understanding how to optimize memory and performance
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: False
```
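A note on the `checkpoint_files` block above: rather than listing all 55 shards, the config gives a `filename_format` and the highest shard index, which the checkpointer expands into concrete file names. The sketch below illustrates that expansion; the `expand_checkpoint_files` helper is hypothetical, written only to show the pattern, not torchtune's actual implementation.

```python
# Illustrative only: how "model-{}-of-{}.safetensors" with max_filename
# "00055" expands into the shard names the checkpointer looks for. This
# helper is a sketch of the pattern, not torchtune's actual code.
def expand_checkpoint_files(filename_format: str, max_filename: str) -> list[str]:
    num_shards = int(max_filename)  # "00055" -> 55 shards
    width = len(max_filename)       # zero-padding width (5)
    return [
        filename_format.format(str(i).zfill(width), max_filename)
        for i in range(1, num_shards + 1)
    ]

files = expand_checkpoint_files("model-{}-of-{}.safetensors", "00055")
print(files[0])   # model-00001-of-00055.safetensors
print(files[-1])  # model-00055-of-00055.safetensors
```

Separately, since `batch_size: 1` here, the effective batch size is batch_size × gradient_accumulation_steps × (number of data-parallel ranks), so `gradient_accumulation_steps` is the lever for a larger effective batch without extra memory.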
recipes/configs/llama4/scout_17B_16E_full.yaml (new file, 87 additions)

```yaml
# Config for multi-device full finetuning in full_finetune_distributed.py
# using a Llama4 17Bx16E MoE model
#
# This config assumes that you've run the following command before launching:
#   tune download meta-llama/Llama-4-Scout-17B-16E-Instruct
#
# To launch on 8 devices, run the following command from root:
#   tune run --nproc_per_node 8 full_finetune_distributed --config llama4/scout_17B_16E_full
#
# You can add specific overrides through the command line. For example, to use a larger bsz:
#   tune run --nproc_per_node 8 full_finetune_distributed --config llama4/scout_17B_16E_full batch_size=8
#
# This config has only been tested on 8xA100 and 16xH100 machines.

output_dir: /tmp/torchtune/llama4_17Bx16E/full

# Modeling arguments
model:
  _component_: torchtune.models.llama4.llama4_scout_17b_16e

tensor_parallel_dim: 2 # For multi-node training we recommend tensor_parallel_dim: 8
tensor_parallel_plan:
  _component_: torchtune.models.llama4.decoder_only_tp_plan
data_parallel_shard_dim: -1 # Will infer based on TP dim, effectively controls FSDP
data_parallel_replicate_dim: 1

tokenizer:
  _component_: torchtune.models.llama4.llama4_transform
  path: /tmp/Llama-4-Scout-17B-16E-Instruct/tokenizer.model
  max_seq_len: null
  max_num_tiles: 16

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Llama-4-Scout-17B-16E-Instruct
  checkpoint_files:
    filename_format: model-{}-of-{}.safetensors
    max_filename: "00050"
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA4
resume_from_checkpoint: False

# Dataset
dataset:
  _component_: torchtune.datasets.alpaca_dataset
  packed: False
seed: null
shuffle: True

# Training arguments
epochs: 1
max_steps_per_epoch: null
batch_size: 1
gradient_accumulation_steps: 1 # Use to increase effective batch size
optimizer:
  _component_: torch.optim.AdamW
  lr: 2e-5
  fused: False
optimizer_in_bwd: False
loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
clip_grad_norm: null

# cuda, cpu, rocm, xpu...
device: cuda

# Memory management / performance
enable_activation_checkpointing: True
enable_activation_offloading: False
fsdp_cpu_offload: True
compile: False # torch.compile, set to True for perf/memory improvement

# Reduced precision
dtype: bf16

# Log metrics during training
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True

# Useful for understanding how to optimize memory and performance
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: False
```
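Both finetuning configs resolve `tensor_parallel_plan` to `torchtune.models.llama4.decoder_only_tp_plan`. Conceptually, such a plan maps decoder module paths to PyTorch `ParallelStyle` objects, which `parallelize_module` then applies across the tensor-parallel group. The sketch below is a minimal, hypothetical plan using PyTorch's built-in styles; the module names are illustrative and do not reproduce torchtune's actual Llama4 plan.

```python
# A minimal sketch of a decoder-only tensor-parallel plan. Keys are module
# paths (wildcards match every decoder layer) and values are ParallelStyle
# objects. Module names are hypothetical, not torchtune's actual layout.
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel

def sketch_decoder_only_tp_plan() -> dict:
    return {
        # Attention: shard q/k/v projections column-wise so each rank holds
        # a slice of the heads; shard the output projection row-wise.
        "layers.*.attn.q_proj": ColwiseParallel(),
        "layers.*.attn.k_proj": ColwiseParallel(),
        "layers.*.attn.v_proj": ColwiseParallel(),
        "layers.*.attn.output_proj": RowwiseParallel(),
        # Feed-forward: up/gate projections column-wise, down projection row-wise.
        "layers.*.mlp.w1": ColwiseParallel(),
        "layers.*.mlp.w3": ColwiseParallel(),
        "layers.*.mlp.w2": RowwiseParallel(),
    }
```

Pairing ColwiseParallel with RowwiseParallel this way keeps intermediate activations sharded, so each attention and MLP block needs only a single all-reduce on its output.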
recipes/configs/llama4/scout_17B_16E_generation_distributed.yaml (new file, 48 additions)

```yaml
# Config for running the InferenceRecipe in dev/generate_v2.py to generate output
# from a Llama4 17Bx16E MoE model
#
# This config assumes that you've run the following command before launching:
#   tune download meta-llama/Llama-4-Scout-17B-16E-Instruct
#
# To launch, run the following command:
#   tune run --nproc_per_node 4 dev/generate_v2_distributed --config llama4/scout_17B_16E_generation_distributed

# Model arguments
model:
  _component_: torchtune.models.llama4.llama4_scout_17b_16e

tensor_parallel_plan:
  _component_: torchtune.models.llama4.decoder_only_tp_plan

tokenizer:
  _component_: torchtune.models.llama4.llama4_transform
  path: /tmp/Llama-4-Scout-17B-16E-Instruct/tokenizer.model
  max_seq_len: null
  max_num_tiles: 16

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Llama-4-Scout-17B-16E-Instruct # You can also point this to your finetuned model!
  checkpoint_files:
    filename_format: model-{}-of-{}.safetensors
    max_filename: "00050"
  output_dir: ./ # No need for an output dir
  model_type: LLAMA4

use_distributed_state_dict: True
use_flex: True # Use PyTorch's FlexAttention for construction of attention masks

# Environment
device: cuda
dtype: bf16
seed: 1234
log_level: INFO

# Generation arguments
prompt:
  system: You are a helpful assistant.
  user:
    text: How are you doing?
max_new_tokens: 200
temperature: 0.6 # 0.8 and 0.6 are popular values to try
top_k: 300
```
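The generation arguments correspond to standard temperature plus top-k sampling. Below is a minimal sketch of how `temperature: 0.6` and `top_k: 300` combine when picking the next token; it is illustrative and not torchtune's sampling implementation.

```python
import torch

def sample_next_token(
    logits: torch.Tensor, temperature: float = 0.6, top_k: int = 300
) -> int:
    # Lower temperature sharpens the distribution toward the argmax.
    logits = logits / temperature
    # Truncate to the top_k most likely tokens and renormalize.
    topk_vals, topk_idx = torch.topk(logits, k=top_k)
    probs = torch.softmax(topk_vals, dim=-1)
    # Sample one token id from the truncated distribution.
    choice = torch.multinomial(probs, num_samples=1)
    return int(topk_idx[choice])

# Example with a random logit vector over a dummy vocabulary:
next_token = sample_next_token(torch.randn(200_000))
```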