
Commit d26bcb6

gau-nernst authored and HDCharles committed
[Low-bit optim] Improve compile time + Fix PyTorch 2.3 support for 4-bit optim (#812)
* disable recompile limit
* remove _prepare_param_groups()
* re-enable FSDP test. update ViT benchmarks
* update
* update
* update readme
1 parent 51011b4 commit d26bcb6

File tree

3 files changed (+69, -190 lines)


test/prototype/test_low_bit_optim.py

Lines changed: 4 additions & 11 deletions
```diff
@@ -75,7 +75,7 @@ def test_quantize_4bit_with_qmap_compile(self, device):
 
 
 class TestOptim(TestCase):
-    @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_4, reason="requires PyTorch >= 2.3")
+    @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3")
     @parametrize("optim_name", ["Adam8bit", "AdamW8bit", "Adam4bit", "AdamW4bit", "AdamFp8", "AdamWFp8"])
     @parametrize("dtype", [torch.float32, torch.bfloat16])
     @parametrize("device", _DEVICES)
@@ -84,10 +84,7 @@ def test_optim_smoke(self, optim_name, dtype, device):
             if not TORCH_VERSION_AT_LEAST_2_4:
                 pytest.skip("FP8 CUDA requires PyTorch >= 2.4")
             if torch.cuda.get_device_capability() < (8, 9):
-                pytest.skip("FP8 requires compute capability >= 8.9")
-
-        # reset cache to avoid hitting cache_size_limit, since the function will re-compile for each test
-        torch._dynamo.reset_code_caches()
+                pytest.skip("FP8 CUDA requires compute capability >= 8.9")
 
         model = nn.Sequential(nn.Linear(32, 256), nn.ReLU(), nn.Linear(256, 32))
         model.to(device=device, dtype=dtype)
@@ -232,12 +229,11 @@ def world_size(self) -> int:
         return 2
 
     @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="OptimState8bit dispatch: attempting to run unimplemented operator/function: aten.as_strided.default")
-    @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="https://github.com/pytorch/ao/issues/652")
     @skip_if_lt_x_gpu(2)
     def test_fsdp2(self):
-        optim_classes = [low_bit_optim.Adam8bit, low_bit_optim.Adam4bit]
+        optim_classes = [low_bit_optim.AdamW8bit, low_bit_optim.AdamW4bit]
         if torch.cuda.get_device_capability() >= (8, 9):
-            optim_classes.append(low_bit_optim.AdamFp8)
+            optim_classes.append(low_bit_optim.AdamWFp8)
 
         self.run_subtests(
             {"optim_cls": optim_classes},
@@ -252,9 +248,6 @@ def _test_fsdp2(self, optim_cls):
             TransformerBlock,
         )
 
-        # seems like cache_size_limit is shared between FSDP processes?
-        torch._dynamo.config.cache_size_limit = 8 * self.world_size
-
         batch_size = 3
         vocab_size = 1024
         seq_len = 64
```
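
For reference, the usage pattern the smoke test above exercises looks roughly like the sketch below. The import path (`torchao.prototype.low_bit_optim`) and the `lr` keyword are assumptions inferred from the file layout and the class names in this diff, not something the commit itself documents:

```python
import torch
import torch.nn as nn

# assumed import path, mirroring test/prototype/test_low_bit_optim.py
from torchao.prototype import low_bit_optim

# same toy model as test_optim_smoke above
model = nn.Sequential(nn.Linear(32, 256), nn.ReLU(), nn.Linear(256, 32))
optim = low_bit_optim.AdamW8bit(model.parameters(), lr=1e-3)  # or Adam4bit, AdamWFp8, ...

for _ in range(2):
    out = model(torch.randn(4, 32))
    out.sum().backward()
    optim.step()       # first call compiles the per-param update
    optim.zero_grad()
```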

torchao/prototype/low_bit_optim/README.md

Lines changed: 18 additions & 28 deletions
```diff
@@ -27,45 +27,35 @@ NOTE:
 - The low-bit optimizers require PyTorch >= 2.3
 - For FP8 optimizers on CUDA, PyTorch >= 2.4 and CUDA compute capability >= 8.9 are required.
 - For 4-bit optimizers, we don't implement rank-1 normalization for quantizing 2nd moment as originally done in the paper.
-- The first training step is expected to be slow since the optimizer needs to be compiled.
 
 ## Benchmarks
 
-Fine-tune [timm](https://github.com/huggingface/pytorch-image-models)'s ViT-H (630M params) on [resisc45](https://huggingface.co/datasets/timm/resisc45) dataset. BF16 AMP, 1 epoch, batch size 8, cosine LR scheduler, 4070Ti SUPER, fixed random seed. Benchmark script is available at [benchmarks/benchmark_low_bit_adam.py](../../../benchmarks/benchmark_low_bit_adam.py).
+Fine-tune [timm](https://github.com/huggingface/pytorch-image-models)'s [ViT-H](https://huggingface.co/timm/vit_huge_patch14_224.orig_in21k) (630M params) on [resisc45](https://huggingface.co/datasets/timm/resisc45) dataset. PyTorch 2.4, BF16 AMP, compiled model, 1 epoch, batch size 8, cosine LR scheduler, 4070Ti SUPER, fixed random seed. Benchmark script is available at [benchmarks/benchmark_low_bit_adam.py](../../../benchmarks/benchmark_low_bit_adam.py).
 
-AdamW impl      | Max memory (GB) | imgs/s | accuracy
-----------------|-----------------|--------|----------
-PyTorch (fused) | 12.23           | 41.8   | 94.38
-bnb 8-bit       | 8.32            | 43.6   | 94.18
-ao 8-bit        | 8.33            | 42.6   | 94.25
-ao FP8 E4M3     | 9.27            | 44.1   | 94.40
-lpmm 4-bit      | 7.72            | 46.0   | 94.29
-ao 4-bit        | 7.72            | 40.0   | 94.03
-lpmm 4-bit (*)  | 7.74            | 26.6   | 94.25
+AdamW impl      | Peak memory allocated (GB) | imgs/s | accuracy
+----------------|----------------------------|--------|----------
+PyTorch (fused) | 12.23                      | 41.9   | 94.52
+bnb 8-bit       | 8.32                       | 43.6   | 94.54
+ao 8-bit        | 8.33                       | 42.5   | 94.30
+ao FP8 E4M3     | 8.33                       | 43.2   | 94.13
+lpmm 4-bit      | 7.72                       | 46.1   | 94.40
+ao 4-bit        | 7.72                       | 42.4   | 94.13
+lpmm 4-bit (*)  | 7.74                       | 26.7   | 94.10
 
 (*) means rank-1 normalization is used for 2nd optimizer state. Refer to [paper](https://arxiv.org/abs/2309.01507) for more details.
 
-Fine-tune [Llama2-7B](https://huggingface.co/meta-llama/Llama-2-7b) on [Alpaca](https://huggingface.co/datasets/tatsu-lab/alpaca) dataset. Full BF16, 1 epoch, A100, fixed random seed. Benchmark is done with [torchtune](https://github.com/pytorch/torchtune). See [#746](https://github.com/pytorch/ao/pull/746) for more details.
+Fine-tune [Llama2-7B](https://huggingface.co/meta-llama/Llama-2-7b) on [Alpaca](https://huggingface.co/datasets/tatsu-lab/alpaca) dataset. PyTorch 2.4, full BF16, 1 epoch, A100, fixed random seed. Benchmark is done with [torchtune 52d1b838](https://github.com/pytorch/torchtune/tree/52d1b838c1c35b5e75fddf8776be400adc36dff5). See [#812](https://github.com/pytorch/ao/pull/812) for more details.
 
-AdamW impl       | Max memory (GB) | toks/s | `truthfulqa_mc2` acc | Compile time
------------------|-----------------|--------|----------------------|-------------
-Not fine-tuned   | -               | -      | 38.95                | -
-PyTorch (fused)  | 52              | ~4500  | 42.12                | ~4 min
-bnb 8-bit        | 39              | ~4000  | 41.98                | ~4 min
-ao 8-bit         | 39              | ~4000  | 42.41                | ~12 min
-ao 4-bit         | 33              | ~3600  | 42.34                | ~4 min
+AdamW impl       | Peak memory allocated (GB) | toks/s | `truthfulqa_mc2` acc
+-----------------|----------------------------|--------|----------------------
+Not fine-tuned   | -                          | -      | 38.95
+PyTorch (fused)  | 51.6                       | 3200   | 42.61
+bnb 8-bit        | 39.3                       | 3000   | 42.75
+ao 8-bit         | 39.1                       | 2900   | 41.50
+ao 4-bit         | 33.2                       | 2900   | 42.27
 
 NOTE: lpmm's 4-bit AdamW does not support BF16 weights.
 
-### Note on compile times
-
-There are 2 approaches to compile optimizer step in low-bit optim:
-
-1. Compile optim step for single param i.e. `torch.compile(single_param_adam)`
-2. Compile optim step for all params i.e. `torch.compile(param_groups_adam)`
-
-Currently Adam8bit and AdamFp8 use approach (2) (with static shape) since it is faster (but compile much slower), while Adam4bit uses approach (1) (with dynamic shape) since there are excessive memory usage for "Adam4bit + approach (2)". Approach (1) requires dynamic shape to avoid hitting recompiles limit.
-
 ## Optimizer CPU offload
 
 This folder also implements optimizer CPU offload (i.e. ZeRO-Offload) for single GPU training. For multi-GPU training, you can use FSDP's built-in CPU offload.
```
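
A minimal sketch of the CPU offload mentioned in the README context above. The wrapper class and its constructor are not part of this diff, so the name `CPUOffloadOptimizer` and its arguments are assumptions based on the rest of that README; check the actual file for the exact API:

```python
import torch
import torch.nn as nn

# assumed name and signature; the CPU offload wrapper is not shown in this diff
from torchao.prototype.low_bit_optim import CPUOffloadOptimizer

model = nn.Linear(1024, 1024, device="cuda")

# ZeRO-Offload style: optimizer state is kept on CPU while parameters stay on GPU;
# extra kwargs (e.g. lr) are assumed to be forwarded to the wrapped optimizer class
optim = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, lr=1e-3)

loss = model(torch.randn(8, 1024, device="cuda")).sum()
loss.backward()
optim.step()
optim.zero_grad()
```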

torchao/prototype/low_bit_optim/adam.py

Lines changed: 47 additions & 151 deletions
```diff
@@ -52,75 +52,63 @@ def _new_buffer(self, p: Tensor, signed: bool):
             out = torch.zeros_like(p)
         return out
 
-    def _prepare_param_groups(self):
-        param_groups = []
-
-        for group in self.param_groups:
-            _group = []
-
-            for p in group["params"]:
-                if p.grad is None:
-                    continue
-
-                grad = p.grad
-                if grad.is_sparse:
-                    raise RuntimeError("Sparse gradient is not supported")
-
-                state = self.state[p]
-
-                # State initialization
-                if len(state) == 0:
-                    state["step"] = torch.tensor(0.0)
-                    state["exp_avg"] = self._new_buffer(p, True)
-                    state["exp_avg_sq"] = self._new_buffer(p, False)
-                    if group["amsgrad"]:
-                        state["max_exp_avg_sq"] = self._new_buffer(p, False)
-
-                state["step"] += 1
-
-                if not isinstance(group["lr"], Tensor):
-                    raise RuntimeError(
-                        "lr was changed to a non-Tensor object. If you want to update lr, please use "
-                        "optim.param_groups[0]['lr'].fill_(new_lr)"
-                    )
-
-                p_grad_state = (
-                    p,
-                    grad,
-                    state["step"],
-                    state["exp_avg"],
-                    state["exp_avg_sq"],
-                    state.get("max_exp_avg_sq", None),
-                )
-                _group.append(p_grad_state)
-
-            param_groups.append((_group, group["lr"], group["betas"], group["weight_decay"], group["eps"]))
-
-        return param_groups
-
     @torch.no_grad()
     def step(self, closure=None):
         loss = None
         if closure is not None:
             with torch.enable_grad():
                 loss = closure()
 
-        param_groups = self._prepare_param_groups()
+        # for a given model, the number of different argument combinations to single_param_adam() is fixed.
+        # thus, it is safe to disable cache limit without the risk of always re-compiling.
+        with torch._dynamo.utils.disable_cache_limit():
+            for group in self.param_groups:
+                for p in group["params"]:
+                    if p.grad is None:
+                        continue
+
+                    grad = p.grad
+                    if grad.is_sparse:
+                        raise RuntimeError("Sparse gradient is not supported")
+
+                    state = self.state[p]
+
+                    # State initialization
+                    if len(state) == 0:
+                        state["step"] = torch.tensor(0.0)
+                        state["exp_avg"] = self._new_buffer(p, True)
+                        state["exp_avg_sq"] = self._new_buffer(p, False)
+                        if group["amsgrad"]:
+                            state["max_exp_avg_sq"] = self._new_buffer(p, False)
+
+                    state["step"] += 1
+
+                    if not isinstance(group["lr"], Tensor):
+                        raise RuntimeError(
+                            "lr was changed to a non-Tensor object. If you want to update lr, please use "
+                            "optim.param_groups[0]['lr'].fill_(new_lr)"
+                        )
+
+                    torch.compile(single_param_adam, fullgraph=True, dynamic=False)(
+                        p,
+                        grad,
+                        state["step"],
+                        state["exp_avg"],
+                        state["exp_avg_sq"],
+                        state.get("max_exp_avg_sq", None),
+                        group["lr"],
+                        group["betas"][0],
+                        group["betas"][1],
+                        group["weight_decay"],
+                        group["eps"],
+                        self.is_adamw,
+                    )
 
-        # static compile optim step for all params in a single graph
-        torch.compile(param_groups_adam, fullgraph=True)(param_groups, self.is_adamw)
         return loss
 
 
-def param_groups_adam(param_groups, is_adamw):
-    for group, lr, (beta1, beta2), weight_decay, eps in param_groups:
-        for p, grad, step, exp_avg, exp_avg_sq, max_exp_avg_sq in group:
-            single_param_adam(
-                p, grad, step, exp_avg, exp_avg_sq, max_exp_avg_sq, lr, beta1, beta2, weight_decay, eps, is_adamw
-            )
-
-
 # this will work with any optim state tensor subclass that implements aten.lerp.Scalar and aten.copy_.default
+# and param tensor subclass that implements aten.add_.Tensor, and aten.addcdiv_.default
 def single_param_adam(
     p: Tensor,
     grad: Tensor,
```
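
The new comment in `step()` above argues that disabling Dynamo's cache limit is safe here because, for a given model, the set of argument combinations to `single_param_adam()` is fixed. A standalone sketch of that compile-caching pattern, with a trivial `toy_update` standing in for `single_param_adam` (not code from this repo):

```python
import torch

def toy_update(p: torch.Tensor, grad: torch.Tensor, lr: float):
    # stand-in for single_param_adam: a trivial in-place update
    p.add_(grad, alpha=-lr)

# static shapes, like the new per-param compile path above
compiled_update = torch.compile(toy_update, fullgraph=True, dynamic=False)

# more distinct parameter shapes than torch._dynamo.config.cache_size_limit (8 by default)
params = [torch.randn(n) for n in (8, 16, 32, 64, 128, 256, 512, 1024, 2048)]

# without this context manager, exceeding the cache limit would make Dynamo stop
# compiling new specializations for this function; with it, each shape keeps its
# own compiled entry, which is the behavior the new step() relies on
with torch._dynamo.utils.disable_cache_limit():
    for p in params:
        compiled_update(p, torch.randn_like(p), 1e-3)
```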
```diff
@@ -198,53 +186,7 @@ def __init__(
 
     @staticmethod
     def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
-        return OptimState4bit.zeros(p.view(-1).shape, signed, block_size, p.device)
-
-    @staticmethod
-    def _unwrap_dtensor(p: Tensor):
-        return p._local_tensor if isinstance(p, DTensor) else p
-
-    @torch.no_grad()
-    def step(self, closure=None):
-        loss = None
-        if closure is not None:
-            with torch.enable_grad():
-                loss = closure()
-
-        param_groups = self._prepare_param_groups()
-
-        # NOTE: right now, torch.compile(param_groups_adam) will have excessive memory usage for 4-bit optim.
-        # thus, as a workaround, we use torch.compile(single_param_adam) and call it for each param.
-
-        # NOTE: we have to create flattened optimizer states since torch.compile() will fail otherwise for
-        # PyTorch 2.3 and 2.4
-        # calling exp_avg.view(-1) will fail torch.compile(single_param_adam) even if we implement the op
-        # correctly for the tensor subclass.
-
-        # unwrap DTensor since DTensor does not work well with dynamic compile
-        # flatten p, grad, and optim state to avoid recompilation
-        for group, lr, (beta1, beta2), weight_decay, eps in param_groups:
-            for p, grad, step, exp_avg, exp_avg_sq, max_exp_avg_sq in group:
-                # DTensor._local_tensor has .requires_grad = False
-                # to avoid recompilation, set p.requires_grad = False and restore it after optim step
-                p.requires_grad_(False)
-                torch.compile(single_param_adam, fullgraph=True, dynamic=True)(
-                    self._unwrap_dtensor(p).view(-1),
-                    self._unwrap_dtensor(grad).view(-1),
-                    step,
-                    self._unwrap_dtensor(exp_avg),
-                    self._unwrap_dtensor(exp_avg_sq),
-                    self._unwrap_dtensor(max_exp_avg_sq) if max_exp_avg_sq is not None else None,
-                    lr,
-                    beta1,
-                    beta2,
-                    weight_decay,
-                    eps,
-                    self.is_adamw,
-                )
-                p.requires_grad_(True)
-
-        return loss
+        return OptimState4bit.zeros(p.shape, signed, block_size, p.device)
 
 
 class AdamFp8(_AdamBase):
```
```diff
@@ -301,53 +243,7 @@ def __init__(
 
     @staticmethod
     def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
-        return OptimState4bit.zeros(p.view(-1).shape, signed, block_size, p.device)
-
-    @staticmethod
-    def _unwrap_dtensor(p: Tensor):
-        return p._local_tensor if isinstance(p, DTensor) else p
-
-    @torch.no_grad()
-    def step(self, closure=None):
-        loss = None
-        if closure is not None:
-            with torch.enable_grad():
-                loss = closure()
-
-        param_groups = self._prepare_param_groups()
-
-        # NOTE: right now, torch.compile(param_groups_adam) will have excessive memory usage for 4-bit optim.
-        # thus, as a workaround, we use torch.compile(single_param_adam) and call it for each param.
-
-        # NOTE: we have to create flattened optimizer states since torch.compile() will fail otherwise for
-        # PyTorch 2.3 and 2.4
-        # calling exp_avg.view(-1) will fail torch.compile(single_param_adam) even if we implement the op
-        # correctly for the tensor subclass.
-
-        # unwrap DTensor since DTensor does not work well with dynamic compile
-        # flatten p, grad, and optim state to avoid recompilation
-        for group, lr, (beta1, beta2), weight_decay, eps in param_groups:
-            for p, grad, step, exp_avg, exp_avg_sq, max_exp_avg_sq in group:
-                # DTensor._local_tensor has .requires_grad = False
-                # to avoid recompilation, set p.requires_grad = False and restore it after optim step
-                p.requires_grad_(False)
-                torch.compile(single_param_adam, fullgraph=True, dynamic=True)(
-                    self._unwrap_dtensor(p).view(-1),
-                    self._unwrap_dtensor(grad).view(-1),
-                    step,
-                    self._unwrap_dtensor(exp_avg),
-                    self._unwrap_dtensor(exp_avg_sq),
-                    self._unwrap_dtensor(max_exp_avg_sq) if max_exp_avg_sq is not None else None,
-                    lr,
-                    beta1,
-                    beta2,
-                    weight_decay,
-                    eps,
-                    self.is_adamw,
-                )
-                p.requires_grad_(True)
-
-        return loss
+        return OptimState4bit.zeros(p.shape, signed, block_size, p.device)
 
 
 class AdamWFp8(_AdamBase):
```
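
The `lr` check kept in `step()` above requires the learning rate to stay a tensor after construction; the `RuntimeError` message spells out the in-place update to use. A minimal sketch of driving an LR schedule that way (the cosine schedule is illustrative, and the import path is assumed as in the earlier test sketch):

```python
import math
import torch
import torch.nn as nn
from torchao.prototype import low_bit_optim  # assumed import path

model = nn.Linear(32, 32)
optim = low_bit_optim.AdamW8bit(model.parameters(), lr=1e-3)

max_lr, total_steps = 1e-3, 100
for step_idx in range(total_steps):
    # update the lr tensor in place, as the error message in step() instructs,
    # instead of assigning a Python float to param_groups[0]["lr"]
    new_lr = 0.5 * max_lr * (1 + math.cos(math.pi * step_idx / total_steps))
    optim.param_groups[0]["lr"].fill_(new_lr)

    model(torch.randn(4, 32)).sum().backward()
    optim.step()
    optim.zero_grad()
```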
