Commit 719dae9
Enable FP8 full finetune distributed
**Summary:** This commit adds FP8 finetuning to the
`full_finetune_distributed` recipe as an optional feature.
For Llama3-8B, we saw up to 14.7% improvement in finetuning
throughput with no degradation in memory usage or accuracy.
This feature is currently gated on PyTorch nightlies since it
depends on recent features added there. However, it will be
available in the next torchtune release.
To use this feature, add the following to your config.yaml:
```
enable_fp8_training: true
fp8_recipe_name: tensorwise # or rowwise, or rowwise_with_gw_hp
tensor_parallel_plan._component_: torchtune.models.llama3.fp8_llama_tp_plan
```
The default setting uses tensorwise scaling +
`enable_fsdp_float8_all_gather=True` (without tensor parallelism),
which led to the largest speedups in our experiments.
Based on #2404 by @nathan-az
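For reference, here is a rough sketch of what the conversion looks like under the hood with torchao's float8 training API; the helper name `convert_for_fp8_training` and the `fqn != "output"` filter below are illustrative assumptions rather than the recipe's exact code:
```
import torch.nn as nn

# torchao's float8 training entry points (present in the torchao commit pinned below)
from torchao.float8 import Float8LinearConfig, convert_to_float8_training


def convert_for_fp8_training(model: nn.Module) -> nn.Module:
    # The default Float8LinearConfig uses tensorwise scaling; enabling fp8 all-gather
    # lets FSDP communicate float8 weights instead of bf16 ones during all-gather.
    config = Float8LinearConfig(enable_fsdp_float8_all_gather=True)
    # Swap eligible nn.Linear layers for their float8 training counterparts. Skipping
    # the output projection is an illustrative choice (its vocab-sized dims may not be
    # divisible by 16, as float8 kernels require), not necessarily what the recipe does.
    return convert_to_float8_training(
        model,
        config=config,
        module_filter_fn=lambda mod, fqn: fqn != "output",
    )
```
The `fp8_recipe_name` values above correspond to torchao recipe configs (e.g. `Float8LinearConfig.from_recipe_name("rowwise")`).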
**Experimentation:** All experiments were run on 4x H100 GPUs with 94GB of memory each. We finetuned the model on the cleaned Alpaca dataset for 1 epoch with a batch size of 16 and `torch.compile` enabled, using the following commits from all 3 repos:
```
torchtune: b818e03 (https://github.com/andrewor14/torchtune/blob/fp8-finetuning)
torchao: 5a78b70
torch: 1017927
```
For Llama3-8B, FP8 finetuning was up to 14.7% faster than the bf16 baseline, with no change in memory usage or quantized accuracy:
```
experiment_name tok/s peak_mem_active peak_mem_alloc peak_mem_reserved
---------------------- ------------------- ----------------- ---------------- -------------------
full 2773.473 (+0.000%) 18.481 (+0.000%) 18.481 (+0.000%) 34.291 (+0.000%)
full_tp 2773.598 (+0.005%) 18.481 (+0.000%) 18.481 (+0.000%) 34.291 (+0.000%)
fp8_noname 3182.220 (+14.738%) 18.484 (+0.014%) 18.484 (+0.014%) 34.325 (+0.097%)
fp8_noname_tp 3159.515 (+13.919%) 18.484 (+0.014%) 18.484 (+0.014%) 34.325 (+0.097%)
fp8_tensorwise 3159.676 (+13.925%) 18.484 (+0.014%) 18.484 (+0.014%) 34.325 (+0.097%)
fp8_tensorwise_tp 3160.202 (+13.944%) 18.484 (+0.014%) 18.484 (+0.014%) 34.325 (+0.097%)
fp8_rowwise 2790.424 (+0.611%) 18.496 (+0.078%) 18.496 (+0.078%) 34.327 (+0.103%)
fp8_rowwise_with_gw_hp 3171.742 (+14.360%) 18.492 (+0.060%) 18.492 (+0.060%) 34.405 (+0.330%)
experiment_name hellaswag_acc wikitext_word_perplexity
---------------------- --------------- --------------------------
full 0.584 (+0.000) 9.419 (+0.000)
full_tp 0.584 (+0.000) 9.415 (-0.004)
fp8_noname 0.585 (+0.000) 9.431 (+0.012)
fp8_noname_tp 0.584 (-0.000) 9.425 (+0.006)
fp8_tensorwise 0.584 (+0.000) 9.421 (+0.002)
fp8_tensorwise_tp 0.584 (-0.000) 9.425 (+0.005)
fp8_rowwise 0.583 (-0.002) 9.421 (+0.002)
fp8_rowwise_with_gw_hp 0.585 (+0.001) 9.405 (-0.014)
```
A few more observations:
- The best tok/s improvement came from the default setting (`fp8_noname`)
- `fp8_rowwise` was the worst FP8 configuration, though still marginally better than the baseline
- Float8 tensor parallelism did not appear to help, and even slightly degraded tok/s for `fp8_noname`
For Llama3.1-8B, we observed similar results, with up to 14.3% faster finetuning and no change in quantized accuracy. However, peak reserved memory increased slightly (about +2%) for most FP8 settings:
```
experiment_name tok/s peak_mem_active peak_mem_alloc peak_mem_reserved
---------------------- ------------------- ----------------- ---------------- -------------------
full 2768.292 (+0.000%) 18.541 (+0.000%) 18.541 (+0.000%) 34.270 (+0.000%)
full_tp 2764.611 (-0.133%) 18.541 (+0.000%) 18.541 (+0.000%) 34.270 (+0.000%)
fp8_noname 3164.370 (+14.308%) 18.542 (+0.008%) 18.542 (+0.008%) 34.963 (+2.021%)
fp8_noname_tp 3144.787 (+13.600%) 18.542 (+0.008%) 18.542 (+0.008%) 34.963 (+2.021%)
fp8_tensorwise 3136.952 (+13.317%) 18.542 (+0.008%) 18.542 (+0.008%) 34.963 (+2.021%)
fp8_tensorwise_tp 3163.867 (+14.289%) 18.542 (+0.008%) 18.542 (+0.008%) 34.963 (+2.021%)
fp8_rowwise 2790.672 (+0.808%) 18.554 (+0.073%) 18.554 (+0.073%) 34.389 (+0.348%)
fp8_rowwise_with_gw_hp 3144.678 (+13.596%) 18.551 (+0.056%) 18.551 (+0.056%) 34.966 (+2.032%)
experiment_name hellaswag_acc wikitext_word_perplexity
---------------------- --------------- --------------------------
full 0.594 (+0.000) 9.087 (+0.000)
full_tp 0.594 (+0.001) 9.089 (+0.002)
fp8_noname 0.593 (-0.001) 9.070 (-0.017)
fp8_noname_tp 0.593 (-0.000) 9.078 (-0.009)
fp8_tensorwise 0.593 (-0.001) 9.061 (-0.026)
fp8_tensorwise_tp 0.593 (-0.001) 9.060 (-0.026)
fp8_rowwise 0.593 (-0.000) 9.086 (-0.001)
fp8_rowwise_with_gw_hp 0.595 (+0.001) 9.087 (+0.000)
```
**Test Plan:**
Experiment command:
```
tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config llama3/8B_full \
enable_fp8_training=true \
fp8_recipe_name=tensorwise \
epochs=1 \
batch_size=16 \
compile=true \
dataset._component_=torchtune.datasets.alpaca_cleaned_dataset \
checkpointer.output_dir="$LOG_DIR" \
output_dir="${LOG_DIR}/metrics" \
metric_logger.log_dir="${LOG_DIR}/metrics"
```
(full script:
https://github.com/andrewor14/torchtune/blob/fp8-finetuning-debug/run_it.sh)
Unit tests:
```
pytest tests -k test_convert_to_float8_training
pytest tests -k test_validate_float8_tp_plan
pytest tests -k test_is_fp8_tensorwise_scaling
```
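As a rough illustration only (not the actual test code added in this commit), a conversion test along these lines could check that eligible linear layers are swapped for their float8 counterparts:
```
import torch.nn as nn

from torchao.float8 import convert_to_float8_training
from torchao.float8.float8_linear import Float8Linear


def test_convert_to_float8_training_sketch():
    # Tiny toy model; the real test presumably uses a torchtune model builder instead.
    model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))
    convert_to_float8_training(model)
    # Float8Linear subclasses nn.Linear, so this asserts every remaining linear layer
    # was converted in place.
    assert all(
        isinstance(m, Float8Linear) for m in model.modules() if isinstance(m, nn.Linear)
    )
```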
File tree: 5 files changed (+230, -30) across recipes, tests/torchtune/training, torchtune/models/llama3, and torchtune/training.