torchao/prototype/low_bit_optim/README.md
This folder implements:
- 4-bit optimizers as outlined in https://arxiv.org/abs/2309.01507
- FP8 optimizers using the native `torch.float8_e4m3fn` dtype (experimental)

The implementation is done fully in Python (with tensor subclasses) and relies on `torch.compile()` to generate efficient fused kernels. Thus, your platform must support `torch.compile()` to use these optimizers. We only test on CPU and CUDA, so there may be bugs or errors on other platforms.
## Usage
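
For reference, a minimal sketch of the intended drop-in usage is shown below. The class name `AdamW4bit` and its constructor arguments are assumptions based on this folder's naming, so check the module's actual exports:

```python
import torch

# The class name below is an assumption based on this folder's naming scheme.
from torchao.prototype.low_bit_optim import AdamW4bit

model = torch.nn.Linear(1024, 1024).cuda()

# Intended as a drop-in replacement for torch.optim.AdamW; optimizer state
# is kept in 4-bit and handled by the torch.compile()-generated fused step.
optim = AdamW4bit(model.parameters(), lr=1e-3)

loss = model(torch.randn(8, 1024, device="cuda")).sum()
loss.backward()
optim.step()
optim.zero_grad()
```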

NOTE: lpmm's 4-bit AdamW does not support BF16 weights.

## Optimizer CPU offload

This folder also implements optimizer CPU offload (i.e. ZeRO-Offload) for single GPU training. Only CUDA is supported. For multi-GPU training, you can use FSDP's built-in CPU offload.
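
A minimal sketch of how the wrapper is meant to be used. `CPUOffloadOptimizer` is the class named in this README, but the exact constructor signature (base optimizer class plus forwarded keyword arguments) is an assumption here:

```python
import torch
from torchao.prototype.low_bit_optim import CPUOffloadOptimizer

model = torch.nn.Linear(1024, 1024).cuda()

# Parameters stay on the GPU; the optimizer state lives in CPU RAM and the
# step runs on CPU. fused=True is assumed to be forwarded to torch.optim.AdamW.
optim = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, fused=True)
```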

`CPUOffloadOptimizer` is not compatible with PyTorch's built-in LR schedulers because it acts only as a wrapper around the actual optimizer (plus extra logic for moving data around). To adjust the LR, you have to update it manually as follows (in fact, the code below works for all PyTorch optimizers too):

```python
lr = ...  # compute your desired LR value
for param_group in optim.param_groups:
    # fused/capturable optimizers may store the LR as a tensor; update it in place
    if isinstance(param_group["lr"], torch.Tensor):
        param_group["lr"].fill_(lr)
    else:
        param_group["lr"] = lr
```

NOTE:
- Since the optimizer step is done on the CPU, it is highly recommended to use a fast CPU optimizer, such as `torch.optim.AdamW(fused=True)` (requires PyTorch 2.4). For other optimizers, you can try applying `torch.compile()` to their optimizer step.
- To minimize the amount of CPU<->GPU data transfer, we keep a copy of the parameters and pre-allocate gradient memory on the CPU. Therefore, expect your RAM usage to increase by 2x model size + optimizer state (which is 2x model size for Adam); see the worked example below.
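
As a back-of-the-envelope example (hypothetical numbers): a 7B-parameter model in BF16 has ~14 GB of weights, so the CPU parameter copy plus pre-allocated gradients add ~28 GB, and the Adam state another ~28 GB, for roughly 56 GB of extra RAM in total.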