
Commit c0a81f9

[low-bit optim] Update docs on supported platforms and caveats (#971)
update
1 parent 96e8fee · commit c0a81f9

torchao/prototype/low_bit_optim/README.md

Lines changed: 13 additions & 2 deletions
@@ -6,7 +6,7 @@ This folder implements:
 - 4-bit optimizers as outlined in https://arxiv.org/abs/2309.01507
 - FP8 optimizers using the native `torch.float8_e4m3fn` dtype (experimental)

-The implementation is fully done in Python (with tensor subclasses) and relies on `torch.compile()` to generate efficient fused kernels.
+The implementation is fully done in Python (with tensor subclasses) and relies on `torch.compile()` to generate efficient fused kernels. Thus, your platform must support `torch.compile()` to use these optimizers. We only test on CPU and CUDA, so there might be bugs or errors on other platforms.

 ## Usage
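
For context on the usage this section points to, here is a minimal sketch of running one of these optimizers on CUDA; the `AdamW8bit` import path and the toy model are illustrative assumptions, not taken from this commit:

```python
# Minimal sketch (illustrative): a low-bit optimizer as a drop-in AdamW replacement.
# Assumes `AdamW8bit` is exported from torchao.prototype.low_bit_optim.
import torch
from torchao.prototype.low_bit_optim import AdamW8bit

model = torch.nn.Linear(1024, 1024, device="cuda", dtype=torch.bfloat16)
optim = AdamW8bit(model.parameters(), lr=1e-3)

x = torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16)
model(x).sum().backward()
optim.step()       # the fused update kernel is generated by torch.compile()
optim.zero_grad()
```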

@@ -58,7 +58,7 @@ NOTE: lpmm's 4-bit AdamW does not support BF16 weights.

 ## Optimizer CPU offload

-This folder also implements optimizer CPU offload (i.e. ZeRO-Offload) for single GPU training. For multi-GPU training, you can use FSDP's built-in CPU offload.
+This folder also implements optimizer CPU offload (i.e. ZeRO-Offload) for single GPU training. Only CUDA is supported. For multi-GPU training, you can use FSDP's built-in CPU offload.

 ```python
 import torch
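
The README example is truncated here at `import torch`; a sketch of how the offloaded optimizer might be constructed and stepped on a single CUDA GPU (the constructor call mirrors the line visible in the next hunk; the toy model and training step are assumptions):

```python
# Sketch: optimizer CPU offload for single-GPU (CUDA) training.
# The CPUOffloadOptimizer(...) call mirrors the line shown in the next hunk;
# the model and the training step are illustrative assumptions.
import torch
from torchao.prototype.low_bit_optim import CPUOffloadOptimizer

model = torch.nn.Linear(1024, 1024, device="cuda")

# Optimizer state lives on CPU and the update runs there with a fast fused AdamW.
optim = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, fused=True)

x = torch.randn(8, 1024, device="cuda")
model(x).sum().backward()
optim.step()       # gradients move to CPU; updated parameters are copied back to GPU
optim.zero_grad()
```
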
@@ -87,6 +87,17 @@ optim = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, fused=True)
 optim.load_state_dict(ckpt["optim"])
 ```

+`CPUOffloadOptimizer` is not compatible with PyTorch's built-in LR schedulers because it only acts as a wrapper around the actual optimizer (plus extra logic for moving data around). To adjust the LR, you have to update it manually as follows (in fact, you can use the code below for all PyTorch optimizers too):
+
+```python
+lr = ...  # compute your desired LR value
+for param_group in optim.param_groups:
+    if isinstance(param_group["lr"], torch.Tensor):
+        param_group["lr"].fill_(lr)
+    else:
+        param_group["lr"] = lr
+```
+
 NOTE:
 - Since the optimizer step is done on CPU, it is highly recommended to use a fast CPU optimizer, such as `torch.optim.AdamW(fused=True)` (requires PyTorch 2.4). For other optimizers, you can try compiling their optimizer step with `torch.compile()`.
 - To minimize the amount of CPU<->GPU data transfer, we keep a copy of the parameters and pre-allocate gradient memory on CPU. Therefore, expect your RAM usage to increase by 2x model size + optimizer state (which is 2x model size for Adam).
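
The first NOTE bullet suggests compiling the optimizer step when no fused CPU implementation is available; one way that might look (the `optim_step` wrapper and the choice of Adagrad are illustrative assumptions, not prescribed by the README):

```python
# Sketch: compile the step of a CPU optimizer that has no fused variant.
import torch

model = torch.nn.Linear(1024, 1024)
cpu_optim = torch.optim.Adagrad(model.parameters(), lr=1e-2)

@torch.compile(fullgraph=False)
def optim_step():
    cpu_optim.step()

model(torch.randn(8, 1024)).sum().backward()
optim_step()
cpu_optim.zero_grad()
```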
