NOTE:
- The low-bit optimizers require PyTorch >= 2.3
- For FP8 optimizers on CUDA, PyTorch >= 2.4 and CUDA compute capability >= 8.9 are required.
- For 4-bit optimizers, we don't implement rank-1 normalization for quantizing the 2nd moment as originally done in the paper.
- The first training step is expected to be slow since the optimizer needs to be compiled.

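
Below is a minimal usage sketch, assuming this folder exports an `AdamW8bit` drop-in for `torch.optim.AdamW` (the 4-bit and FP8 variants would be used the same way); check the module's exports for the exact class names.

```python
# Minimal sketch, assuming torchao.prototype.low_bit_optim exports AdamW8bit
# as a drop-in replacement for torch.optim.AdamW.
import torch
from torchao.prototype.low_bit_optim import AdamW8bit

model = torch.nn.Linear(1024, 1024).cuda()
optim = AdamW8bit(model.parameters(), lr=1e-4)

for _ in range(3):
    loss = model(torch.randn(8, 1024, device="cuda")).mean()
    loss.backward()
    optim.step()       # the first step is slow: the fused optimizer step gets compiled
    optim.zero_grad()
```
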
## Benchmarks

Fine-tune [timm](https://github.com/huggingface/pytorch-image-models)'s [ViT-H](https://huggingface.co/timm/vit_huge_patch14_224.orig_in21k) (630M params) on the [resisc45](https://huggingface.co/datasets/timm/resisc45) dataset. PyTorch 2.4, BF16 AMP, compiled model, 1 epoch, batch size 8, cosine LR scheduler, 4070Ti SUPER, fixed random seed. The benchmark script is available at [benchmarks/benchmark_low_bit_adam.py](../../../benchmarks/benchmark_low_bit_adam.py).

(*) means rank-1 normalization is used for the 2nd optimizer state. Refer to the [paper](https://arxiv.org/abs/2309.01507) for more details.

Fine-tune [Llama2-7B](https://huggingface.co/meta-llama/Llama-2-7b) on the [Alpaca](https://huggingface.co/datasets/tatsu-lab/alpaca) dataset. PyTorch 2.4, full BF16, 1 epoch, A100, fixed random seed. The benchmark is done with [torchtune 52d1b838](https://github.com/pytorch/torchtune/tree/52d1b838c1c35b5e75fddf8776be400adc36dff5). See [#812](https://github.com/pytorch/ao/pull/812) for more details.

AdamW impl | Max memory (GB) | toks/s | `truthfulqa_mc2` acc | Compile time

NOTE: lpmm's 4-bit AdamW does not support BF16 weights.
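
To make the "full BF16" setting concrete: the model weights themselves are bfloat16 while the optimizer state is quantized. A small sketch, assuming this folder exports an `AdamW4bit` class (verify the exact name against the module):

```python
# Sketch, assuming an AdamW4bit export: BF16 weights with 4-bit optimizer state.
import torch
from torchao.prototype.low_bit_optim import AdamW4bit

model = torch.nn.Linear(4096, 4096, dtype=torch.bfloat16, device="cuda")
optim = AdamW4bit(model.parameters(), lr=1e-5)  # BF16 weights are supported here, unlike lpmm's 4-bit AdamW
```
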

### Note on compile times

There are 2 approaches to compiling the optimizer step in low-bit optim:

1. Compile the optim step for a single param, i.e. `torch.compile(single_param_adam)`
2. Compile the optim step for all params, i.e. `torch.compile(param_groups_adam)`

Currently Adam8bit and AdamFp8 use approach (2) (with static shapes) since it is faster (though it compiles much more slowly), while Adam4bit uses approach (1) (with dynamic shapes) since "Adam4bit + approach (2)" has excessive memory usage. Approach (1) requires dynamic shapes to avoid hitting the recompile limit.
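
As an illustration (not the actual kernels in this folder), the two approaches can be sketched as follows, with `single_param_adam` and `param_groups_adam` as simplified stand-ins for the functions of the same name:

```python
# Toy sketch of the two compile strategies (simplified Adam math, no bias correction).
import torch

def single_param_adam(p, grad, exp_avg, exp_avg_sq, lr, beta1, beta2, eps):
    exp_avg.lerp_(grad, 1 - beta1)                 # m = beta1*m + (1-beta1)*g
    exp_avg_sq.lerp_(grad * grad, 1 - beta2)       # v = beta2*v + (1-beta2)*g^2
    p.add_(exp_avg / (exp_avg_sq.sqrt() + eps), alpha=-lr)

def param_groups_adam(params, grads, exp_avgs, exp_avg_sqs, lr, beta1, beta2, eps):
    for p, g, m, v in zip(params, grads, exp_avgs, exp_avg_sqs):
        single_param_adam(p, g, m, v, lr, beta1, beta2, eps)

params = [torch.randn(n) for n in (256, 512)]
grads = [torch.randn_like(p) for p in params]
exp_avgs = [torch.zeros_like(p) for p in params]
exp_avg_sqs = [torch.zeros_like(p) for p in params]

# Approach (1): compile the per-param step once and loop in Python.
# dynamic=True avoids recompiling for each new param shape.
step_fn = torch.compile(single_param_adam, dynamic=True)
for p, g, m, v in zip(params, grads, exp_avgs, exp_avg_sqs):
    step_fn(p, g, m, v, 1e-4, 0.9, 0.999, 1e-8)

# Approach (2): compile one graph over all params (static shapes):
# faster steady-state step, but a much longer compile.
torch.compile(param_groups_adam)(params, grads, exp_avgs, exp_avg_sqs, 1e-4, 0.9, 0.999, 1e-8)
```
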
## Optimizer CPU offload
This folder also implements optimizer CPU offload (i.e. ZeRO-Offload) for single GPU training. For multi-GPU training, you can use FSDP's built-in CPU offload.
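
A minimal sketch of single-GPU usage follows; it assumes a `CPUOffloadOptimizer` wrapper exported by this folder that takes the base optimizer class and an `offload_gradients` flag, so treat the exact name and signature as an assumption to verify against the module.

```python
# Sketch, assuming a CPUOffloadOptimizer wrapper: optimizer states (and optionally
# gradients) live in CPU memory while the model parameters stay on GPU.
import torch
from torchao.prototype.low_bit_optim import CPUOffloadOptimizer

model = torch.nn.Linear(4096, 4096).cuda()
optim = CPUOffloadOptimizer(
    model.parameters(),
    torch.optim.AdamW,        # base optimizer whose states are kept on CPU
    offload_gradients=True,   # assumed flag: also offload gradients (single-GPU only)
    lr=1e-4,
)

loss = model(torch.randn(8, 4096, device="cuda")).mean()
loss.backward()
optim.step()
optim.zero_grad()
```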