add flags to readme #2003

Merged · 12 commits · Nov 18, 2024
45 changes: 44 additions & 1 deletion README.md
@@ -39,7 +39,7 @@ torchtune currently supports the following models.

| Model | Sizes |
|-----------------------------------------------|-----------|
| [Llama3.2-Vision](https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2#-llama-3.2-vision-models-(11b/90b)-) | 11B [[models](torchtune/models/llama3_2_vision/_model_builders.py), [configs](recipes/configs/llama3_2_vision/)] |
| [Llama3.2-Vision](https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2#-llama-3.2-vision-models-(11b/90b)-) | 11B, 90B [[models](torchtune/models/llama3_2_vision/_model_builders.py), [configs](recipes/configs/llama3_2_vision/)] |
| [Llama3.2](https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2) | 1B, 3B [[models](torchtune/models/llama3_2/_model_builders.py), [configs](recipes/configs/llama3_2/)] |
| [Llama3.1](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1) | 8B, 70B, 405B [[models](torchtune/models/llama3_1/_model_builders.py), [configs](recipes/configs/llama3_1/)] |
| [Llama3](https://llama.meta.com/llama3) | 8B, 70B [[models](torchtune/models/llama3/_model_builders.py), [configs](recipes/configs/llama3/)] |
@@ -103,6 +103,49 @@ If you are interested in running on different hardware or with different models,

 

### Optimization flags

torchtune exposes a number of levers for memory efficiency and performance. The table below demonstrates the effect of applying some of these techniques sequentially to the Llama 3.2 3B model. Each technique is added on top of the previous one, except for the LoRA and QLoRA rows, which do not use `optimizer_in_bwd` or the `AdamW8bit` optimizer. (A sketch of how these flags map onto config overrides follows the table.)

**Baseline** (see the command sketch after this list):
- **Model:** Llama 3.2 3B
- **Batch size:** 2
- **Max seq len:** 4096
- **Precision:** bf16
- **Hardware:** A100
- **Recipe:** full_finetune_single_device
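
For reference, the baseline numbers come from running the bare recipe with these settings and no optimization flags enabled. Below is a minimal sketch; the config name `llama3_2/3B_full_single_device` and the `dtype` override are assumptions, following the naming used by the QLoRA command further down:

```bash
# Baseline: plain single-device full finetune, no optimization flags.
# The config name is an assumption; run `tune ls` to see the available configs.
tune run full_finetune_single_device --config llama3_2/3B_full_single_device \
  batch_size=2 \
  tokenizer.max_seq_len=4096 \
  dtype=bf16
```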

| Technique | Peak Memory Active (GiB) | % Change Memory vs Previous | Tokens Per Second | % Change Tokens/sec vs Previous|
|:--|:-:|:-:|:-:|:-:|
| Baseline | 25.5 | - | 2091 | - |
| [+ Packed Dataset](https://pytorch.org/torchtune/main/basics/packing.html) | 60.0 | +135.16% | 7075 | +238.40% |
| [+ Compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) | 51.0 | -14.93% | 8998 | +27.18% |
| [+ Chunked Cross Entropy](https://pytorch.org/torchtune/main/generated/torchtune.modules.loss.CEWithChunkedOutputLoss.html) | 42.9 | -15.83% | 9174 | +1.96% |
| [+ Activation Checkpointing](https://pytorch.org/torchtune/main/tutorials/memory_optimizations.html#activation-checkpointing) | 24.9 | -41.93% | 7210 | -21.41% |
| [+ Fuse optimizer step into backward](https://pytorch.org/torchtune/main/tutorials/memory_optimizations.html#fusing-optimizer-step-into-backward-pass) | 23.1 | -7.29% | 7309 | +1.38% |
| [+ Activation Offloading](https://pytorch.org/torchtune/main/tutorials/memory_optimizations.html#activation-offloading) | 21.8 | -5.48% | 7301 | -0.11% |
| [+ 8-bit AdamW](https://pytorch.org/torchtune/main/tutorials/memory_optimizations.html#lower-precision-optimizers) | 17.6 | -19.63% | 6960 | -4.67% |
| [LoRA](https://pytorch.org/torchtune/main/tutorials/memory_optimizations.html#glossary-lora) | 8.5 | -51.61% | 8210 | +17.96% |
| [QLoRA](https://pytorch.org/torchtune/main/tutorials/memory_optimizations.html#quantized-low-rank-adaptation-qlora) | 4.6 | -45.71% | 8035 | -2.13% |
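
Each row corresponds to switching on one more flag via config overrides on the command line. As a rough illustration, the full-finetune stack through the "+ 8-bit AdamW" row might look like the command below; the override names mirror the QLoRA command that follows, while the config name and the `bitsandbytes.optim.AdamW8bit` component path are assumptions (substitute the 8-bit AdamW implementation you actually use):

```bash
# Full finetune with the table's memory/performance flags enabled.
# llama3_2/3B_full_single_device and bitsandbytes.optim.AdamW8bit are assumed;
# fusing the optimizer step into backward requires gradient_accumulation_steps=1.
tune run full_finetune_single_device --config llama3_2/3B_full_single_device \
  dataset.packed=True \
  compile=True \
  loss=torchtune.modules.loss.CEWithChunkedOutputLoss \
  enable_activation_checkpointing=True \
  enable_activation_offloading=True \
  optimizer_in_bwd=True \
  optimizer._component_=bitsandbytes.optim.AdamW8bit \
  tokenizer.max_seq_len=4096 \
  gradient_accumulation_steps=1 \
  batch_size=2
```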

Compared to the baseline, the final row of the table (QLoRA with the packed dataset and the other techniques above, minus `optimizer_in_bwd` and 8-bit AdamW) uses **81.9%** less peak memory while delivering a **284.3%** increase in tokens per second. It can be run via the command:
```bash
tune run lora_finetune_single_device --config llama3_2/3B_qlora_single_device \
dataset.packed=True \
compile=True \
loss=torchtune.modules.loss.CEWithChunkedOutputLoss \
enable_activation_checkpointing=True \
optimizer_in_bwd=False \
enable_activation_offloading=True \
optimizer._component_=torch.optim.AdamW \
tokenizer.max_seq_len=4096 \
gradient_accumulation_steps=1 \
epochs=1 \
batch_size=2
```

 

## Installation

torchtune is tested with the latest stable PyTorch release as well as the preview nightly version. torchtune leverages