FSDP 2 low bit optim broken on pytorch nightlies #652
Just realized ao/test/prototype/test_low_bit_optim.py (line 232 at 261d0a4) is broken on PyTorch nightlies.
The error message is very cryptic.
Yeah, the version problem is kinda getting out of hand, I'll fix that ASAP. Regarding the error, @awgu and @weifengpy are usually my go-tos for FSDP2 issues.
I think low-bit optimizer + FSDP2 is actually low-bit optimizer + DTensor + torch.compile.
(taking a look)
The problem is that we have a pretty complicated input to the compiled region: our input is a DTensor (a tensor subclass). I have a min repro here: pytorch/pytorch#133274. In the meantime, I also found a tweak that gets me past the error, although I'm not sure that we actually want to land it to eager FSDP2 (cc @awgu).
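As a minimal sketch of the kind of input being described (assuming a `torchrun` launch; this is not the min repro from pytorch/pytorch#133274): after `fully_shard`, each parameter the optimizer sees is a DTensor, and so is its `.grad`, so a compiled `optim.step()` receives tensor-subclass graph inputs that also carry gradients.

```python
import torch
import torch.distributed as dist
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._tensor import DTensor

# Assumes a torchrun launch so a default process group can be created.
dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = torch.nn.Linear(1024, 1024, device="cuda")
fully_shard(model)
model(torch.randn(8, 1024, device="cuda")).sum().backward()

# Both the parameter and its gradient are (sharded) DTensors, so the compiled
# optimizer step sees tensor-subclass inputs that also have .grad populated.
assert isinstance(model.weight, DTensor)
assert isinstance(model.weight.grad, DTensor)
```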
Thank you for the quick debug. May I ask: you mentioned the input in question has […]?
I was looking at the values of […]. And empirically, […].
@bdhirsh I see, thank you for the clarification. The subclass you were referring to is DTensor, not my custom subclass for the quantized optimizer state. That makes sense. But it also raises another question: how come the other FSDP2 tests in torchao did not fail 😅. Then I remembered NF4 is not trainable, so it won't have optimizer state. In the end, is it correct to say that the bug is more about FSDP2 + torch.compile(optim_step)? If it is not isolated to the custom optimizer, perhaps we can add some tests for this scenario in PyTorch core or other repos too.
Yeah, I could definitely believe that this is true (I don't have the bandwidth to add those tests, but if someone wants to try making a smaller repro that doesn't use your low-bit optimizer they are welcome to 😄). Then again, I think this is a pretty one-off bug that we just expected to be very rarely hit (we haven't had to exercise a lot of code in compile where our tensor inputs to the graph also have […]).
I can confirm normal optimizers have this bug too:

```python
import torch
from torch.distributed._composable.fsdp import fully_shard
from torch.testing._internal.distributed._tensor.common_dtensor import ModelArgs, Transformer

batch_size = 3
vocab_size = 1024
seq_len = 64
model_args = ModelArgs(
    n_layers=3,
    n_heads=4,
    dim=1024,
    vocab_size=vocab_size,
    max_seq_len=seq_len,
)
model = Transformer(model_args).cuda()
for m in model.layers:
    fully_shard(m)
fully_shard(model)
optim = torch.optim.AdamW(model.parameters(), lr=1e-2, foreach=False, fused=False)

# compile optimizer
optim.step = torch.compile(optim.step)

for iter_idx in range(5):
    inp = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda")
    model(inp).mean().backward()
    optim.step()
    optim.zero_grad()
```

Run with `torchrun`. Agree with your last point 😄! Hopefully the fix in PyTorch core is coming soon! Thank you for the help!
@bdhirsh I noticed that the FSDP test for the low-bit optim now passes with torch nightly. Was it fixed in core recently? I didn't see any updates in pytorch/pytorch#133274.
Hmm, that's strange - I ran the non-subclass repro you put above locally and it still fails for me: […]
Hmm, I think it might be because I changed the way I compile the optim step. Now I compile the optim step for each param with static shapes, instead of the optim step for all params (#812). In that case the issue in PyTorch core is still there, but we can probably close this issue?
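A minimal sketch of that per-parameter approach (hypothetical names, not the actual #812 code): the single-parameter update is compiled once with `dynamic=False`, and the Python loop over parameters stays eager, so each distinct parameter shape gets its own static-shape compilation.

```python
import torch

# Hypothetical single-parameter AdamW step, compiled with static shapes.
# torch.compile recompiles once per distinct parameter shape it encounters.
@torch.compile(dynamic=False)
def single_param_adamw(p, grad, exp_avg, exp_avg_sq, step, lr, beta1, beta2, eps, weight_decay):
    p.mul_(1 - lr * weight_decay)               # decoupled weight decay
    exp_avg.lerp_(grad, 1 - beta1)              # first moment EMA
    exp_avg_sq.lerp_(grad.square(), 1 - beta2)  # second moment EMA
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step
    denom = (exp_avg_sq / bias_correction2).sqrt().add_(eps)
    p.addcdiv_(exp_avg, denom, value=-lr / bias_correction1)

# Toy usage; in the real low-bit optimizer the state tensors would be the
# quantized per-parameter buffers, and the loop would run over param groups.
p = torch.randn(128, 64)
grad = torch.randn_like(p)
exp_avg, exp_avg_sq = torch.zeros_like(p), torch.zeros_like(p)
single_param_adamw(p, grad, exp_avg, exp_avg_sq, step=1, lr=1e-2,
                   beta1=0.9, beta2=0.95, eps=1e-8, weight_decay=1e-2)
```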
Ah yeah, great - this is definitely just a bug at the intersection of subclasses + dynamic shapes + optimizer/gradient, so if you're OK with static shapes only for now (which might be better for perf anyway), closing this issue sounds fine to me.
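For the repro above, a one-line variant that requests static shapes explicitly (assuming, as suggested here, that static shapes sidestep the bug) would be:

```python
# Force static shapes for the compiled optimizer step instead of the default
# automatic dynamic-shape behavior (assumption: this avoids the dynamic path).
optim.step = torch.compile(optim.step, dynamic=False)
```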
To repro:
`python test/prototype/test_low_bit_optim.py TestFSDP2.test_fsdp2`