Error using torchao.prototype.low_bit_optim.AdamW8bit #1978
hey @acisseJZhong, is this with one of our models? If you could share the config and make it reproducible, we could ping the folks from torchao.
@acisseJZhong What is your PyTorch and torchao version? I recommend using PyTorch nightly (>=2.6) for torchao's AdamW8bit + FSDP, since there are mysterious errors otherwise. For single GPU, it should be fine for PyTorch >= 2.3.
Hi @gau-nernst, I am using torch 2.6.0 and torchao 0.7.0, would that be a problem?
Hmm, that's unexpected. Can you show me the exact PyTorch nightly version (with date), as well as the torchtune config that encounters this error (also your distributed setup, i.e. what GPU and how many)? I will try to reproduce on my side.
I am using
and
I was running a customized model, but I just ran the llama3_2_vision/11B_full config with
Command I used is
@gau-nernst thanks, please let me know if you are seeing a different error.
Using the following versions
This is mainly a problem of DTensor + torch.compile. One problem is that positional encodings in torchtune are not contiguous after resizing. Adding the following patch will make the previous error go away.
Patch torchtune:
```diff
diff --git a/torchtune/models/clip/_position_embeddings.py b/torchtune/models/clip/_position_embeddings.py
index cd1ea594..09a98862 100644
--- a/torchtune/models/clip/_position_embeddings.py
+++ b/torchtune/models/clip/_position_embeddings.py
@@ -319,7 +319,7 @@ class TiledTokenPositionalEmbedding(nn.Module):
         # add cls token back in
         local_pos_embed = torch.cat([cls_token, local_pos_embed], dim=0)
-        return local_pos_embed
+        return local_pos_embed.contiguous()
 
     # TODO: Switch to public method after 2.5 is stable
     @staticmethod
@@ -436,7 +436,7 @@ class TiledTokenPositionalEmbedding(nn.Module):
         # add cls token back in
         global_pos_embed = torch.cat([cls_embed, pos_embed], dim=2)
-        return global_pos_embed
+        return global_pos_embed.contiguous()
 
     def forward(self, x: torch.Tensor, aspect_ratio: torch.Tensor) -> torch.Tensor:
         """
@@ -633,7 +633,7 @@ class TilePositionalEmbedding(nn.Module):
         )
         # permute to the original shape
         embedding = embedding.permute(2, 3, 0, 1)
-        return embedding
+        return embedding.contiguous()
 
     def forward(self, x: torch.Tensor, aspect_ratio: torch.Tensor) -> torch.Tensor:
         """
```
With the above patch, there is a new error:
Again, this is related to DTensor + torch.compile(). pytorch/ao#652. For some reason, torch.compile() will try to use dynamic shapes for DTensor, even though we specifically request it not to. I will see if I can make a small reproducible snippet and create an issue in PyTorch core. @acisseJZhong In the meantime, you can apply the following patch to torchao to unblock your use case. (Note that it now only works for FSDP2, which uses DTensor and has the `_local_tensor` attribute.)
Patch torchao:
```diff
diff --git a/torchao/prototype/low_bit_optim/adam.py b/torchao/prototype/low_bit_optim/adam.py
index 1c371897..052f469d 100644
--- a/torchao/prototype/low_bit_optim/adam.py
+++ b/torchao/prototype/low_bit_optim/adam.py
@@ -112,11 +112,11 @@ class _AdamBase(Optimizer):
                 )
 
                 torch.compile(single_param_adam, fullgraph=True, dynamic=False)(
-                    p,
-                    grad,
+                    p._local_tensor,
+                    grad._local_tensor,
                     state["step"],
-                    state["exp_avg"],
-                    state["exp_avg_sq"],
+                    state["exp_avg"]._local_tensor,
+                    state["exp_avg_sq"]._local_tensor,
                     state.get("max_exp_avg_sq", None),
                     group["lr"],
                     group["betas"][0],
```
The torchtune patch can be submitted as a PR to torchtune cc @felipemello1 @ebsmothers. Not sure if I will apply the torchao patch to torchao itself though, because technically it is a DTensor + torch.compile problem, and the general consensus from the PyTorch optimizer team is that I shouldn't need to unwrap DTensor for the optimizer.
@acisseJZhong Can you check if pytorch/ao#1269 fixes your issue (instead of unwrapping DTensor with `._local_tensor`)?
Hey @gau-nernst, I patched the fix into my local torchao version, and it's failing with the same error :(
Which same error? `Mutations on non-contiguous inputs` or `assert len(dynamic_sizes) == dim`? Have you applied the torchtune patch? It's merged to main, so you can try torchtune main also. For your own model, you also need to check that all params in your model are contiguous.
I patched both the torchtune and torchao fixes you linked before, and I am seeing the same error as in my initial bug report:
I am seeing this for both the Llama3.2 vision 11B full finetune and my own model. Were you able to run 11B vision with this optimizer?
I could run Llama3.2 vision 11B with my original patch (add `.contiguous()`). The
I patched the fixes in https://github.com/pytorch/torchtune/pull/1986/files that added `.contiguous()`. What is this
For the torchao patch, originally I fixed it with this #1978 (comment) (see "Patch torchao"). Then I came up with another approach in pytorch/ao#1269, as you have seen. You only need to apply either one of the patches. Just remember I tested it with 2x A100, while you ran with 4x GPU. It may matter, as some params in the ViT (e.g. pos_embed) have a very small first dim, so I'm not sure how FSDP will shard them if first dim < num_gpu. Perhaps you can try with a 2x GPU setup also?
Thanks for replying! I tried again with the Llama3.2 vision 11B model, and the `Mutations on non-contiguous inputs` error is fixed! Now I am seeing another error:
Regarding running my own model, I am still seeing the non-contiguous error. Wondering how you managed to identify which tensor is non-contiguous. In my case, I am not using the vision encoder, just the regular RoPE here: https://github.com/pytorch/torchtune/blob/main/torchtune/modules/position_embeddings.py
I just re-ran Llama3.2 vision 11B again with
Regarding non-contiguous tensors for your model, you can try doing something like this after creating the model and loading its weights:
```python
for name, p in model.named_parameters():
    if not p.is_contiguous():
        print(name)
```
Then you know which params are not contiguous, and can add `.contiguous()` accordingly.
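To see why the positional-embedding patch helps, here is a small illustrative snippet (hypothetical shapes, not torchtune code) showing that `permute()` returns a non-contiguous view and `.contiguous()` materializes it:
```python
# Illustrative only (hypothetical shapes): permute() returns a strided view
# over the same storage, so the result is not contiguous, which is what the
# compiled 8-bit optimizer kernels complain about.
import torch

embedding = torch.randn(2, 3, 4, 5).permute(2, 3, 0, 1)
print(embedding.is_contiguous())               # False
print(embedding.contiguous().is_contiguous())  # True -> what the patch returns
```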
This error is because the 1st dim of the param is not divisible by world_size. If your own model does not have an odd 1st dimension (i.e. the 1st dim is divisible by world_size), this shouldn't matter for you. You only need to check for non-contiguous params. Meanwhile, I will try to push some fixes to torchao...
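As a rough illustration of the uneven-sharding case described above (the sizes below are hypothetical), FSDP2 shards parameters along dim 0, so a first dim that is not divisible by the world size leaves ranks with differently sized local shards:
```python
# Hypothetical sizes, for illustration only: dim-0 sharding of a parameter
# whose first dim (e.g. a positional embedding with 1601 rows) is not divisible
# by world_size produces uneven local shards across ranks.
import torch

world_size = 4
param = torch.empty(1601, 1280)
shards = torch.chunk(param, world_size, dim=0)
print([s.shape[0] for s in shards])  # [401, 401, 401, 398] -> uneven shards
```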
The latest error was because I constructed the DTensor incorrectly when there was uneven sharding (i.e. 1st dim is not divisible by world size). I pushed a fix to pytorch/ao#1269. Can you try this latest patch?
@gau-nernst, I think we can take a look at what param has this shape, i.e. whether it is in the model, or it's an activation, or something like that. And on our side we can possibly pad it to 1608, without having to change AdamW8bit. I remember seeing this error before in a different context.
@felipemello1 The AdamW8bit code needs to be updated anyway; the current one is buggy (hence this error happens). Even if we pad it to 1608, if world_size is larger there will still be uneven sharding, i.e. 1608 / 16 = 100.5. I don't have access to a machine atm, but I'm guessing it's e.g. `torchtune/models/clip/_position_embeddings.py`, lines 35 to 39 at 4b6877a.
@acisseJZhong Do you mind confirming that the latest patch at pytorch/ao#1269 works for you? Once you confirm it, I will merge the torchao PR.
@gau-nernst I confirmed Llama3.2 11B vision now works with the latest patch.
Running the full distributed finetune recipe, I want to save memory using the `torchao.prototype.low_bit_optim.AdamW8bit` optimizer. This is how I used it:
and this is the error log I am seeing:
Wondering how I should approach solving this bug? Many thanks!
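The reporter's exact snippet and error log were not captured above; as a rough, minimal sketch (hypothetical model and hyperparameters, single GPU), the optimizer is a drop-in replacement for `torch.optim.AdamW`:
```python
# Minimal single-GPU sketch, not the reporter's recipe/config (which was not
# captured above): AdamW8bit is used like any other torch.optim optimizer.
import torch
from torchao.prototype.low_bit_optim import AdamW8bit

model = torch.nn.Linear(4096, 4096, device="cuda", dtype=torch.bfloat16)
optim = AdamW8bit(model.parameters(), lr=2e-5)

x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)
model(x).sum().backward()
optim.step()
optim.zero_grad()
```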