Error using torchao.prototype.low_bit_optim.AdamW8bit #1978

Closed
acisseJZhong opened this issue Nov 9, 2024 · 20 comments
Labels
bug Something isn't working

Comments

@acisseJZhong
Contributor

Running the full distributed finetune recipe, I want to save memory by using the torchao.prototype.low_bit_optim.AdamW8bit optimizer. This is how I used it:

optimizer:
  _component_: torchao.prototype.low_bit_optim.AdamW8bit
  lr: 2e-5
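
(As I understand torchtune's config instantiation, this should be roughly equivalent to the following, with model coming from the recipe:)

from torchao.prototype.low_bit_optim import AdamW8bit

# Build the 8-bit AdamW optimizer over the model parameters with the configured lr.
optimizer = AdamW8bit(model.parameters(), lr=2e-5)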

and this is the error log I am seeing:

Traceback (most recent call last):
  File "/torchtune/recipes/full_finetune_distributed.py", line 845, in <module>
    sys.exit(recipe_main())
             ^^^^^^^^^^^^^
  File "/torchtune/torchtune/config/_parse.py", line 99, in wrapper
    sys.exit(recipe_main(conf))
             ^^^^^^^^^^^^^^^^^
  File "/torchtune/recipes/full_finetune_distributed.py", line 840, in recipe_main
    recipe.train()
  File "/torchtune/recipes/full_finetune_distributed.py", line 743, in train
    self._optimizer.step()
  File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/optim/optimizer.py", line 493, in wrapper
    out = func(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
  File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torchao/prototype/low_bit_optim/adam.py", line 96, in step
    torch.compile(single_param_adam, fullgraph=True, dynamic=False)(
  File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 556, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1423, in __call__
    return self._torchdynamo_orig_callable(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 549, in __call__
    return _compile(
           ^^^^^^^^^
  File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 977, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 708, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_utils_internal.py", line 95, in wrapper_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 743, in _compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1348, in transform_code_object
    transformations(instructions, code_options)
  File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 233, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 662, in transform
    tracer.run()
  File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2909, in run
    super().run()
  File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1115, in run
    while self.step():
          ^^^^^^^^^^^
  File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1027, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3100, in RETURN_VALUE
    self._return(inst)
  File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3085, in _return
    self.output.compile_subgraph(
  File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1176, in compile_subgraph
    self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
  File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1414, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1463, in call_user_compiler
    return self._call_user_compiler(gm)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1512, in _call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1493, in _call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_dynamo/repro/after_dynamo.py", line 130, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_dynamo/repro/after_dynamo.py", line 130, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/__init__.py", line 2294, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1707, in compile_fx
    return aot_autograd(
           ^^^^^^^^^^^^^
  File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_dynamo/backends/common.py", line 72, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1103, in aot_module_simplified
    compiled_fn = dispatch_and_compile()
                  ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1079, in dispatch_and_compile
    compiled_fn, _ = create_aot_dispatcher_function(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 527, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 635, in _create_aot_dispatcher_function
    fw_metadata = run_functionalized_fw_and_collect_metadata(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 226, in inner
    raise RuntimeError(
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
RuntimeError: Mutations on non-contiguous inputs are currently not allowed on tensor subclasses
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

Wondering how I should approach solving this bug? Many thanks!

acisseJZhong added the bug (Something isn't working) label on Nov 9, 2024
@felipemello1
Contributor

hey @acisseJZhong, is this with one of our models? If you could share the config and make it reproducible, we could ping the folks from torchao.

@gau-nernst
Contributor

@acisseJZhong What are your PyTorch and torchao versions? I recommend using PyTorch nightly (>=2.6) for torchao's AdamW8bit + FSDP, since there are mysterious errors otherwise. For single GPU, it should be fine with PyTorch>=2.3.
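
(A quick way to check the exact installed builds, assuming both packages expose __version__ as usual:)

import torch
import torchao

# Prints the full version strings, including any nightly date suffix.
print(torch.__version__)
print(torchao.__version__)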

@acisseJZhong
Contributor Author

Hi @gau-nernst, I am using torch 2.6.0 and torchao 0.7.0. Would that be a problem?

@gau-nernst
Contributor

Hmm, that's unexpected. Can you show me the exact PyTorch nightly version (with date), as well as the torchtune config that encounters this error (also your distributed setup, i.e. what GPU and how many)? I will try to reproduce on my side.

@acisseJZhong
Contributor Author

acisseJZhong commented Nov 11, 2024

I am using

Name: torch
Version: 2.6.0.dev20241105+cu121

and

Name: torchao
Version: 0.7.0.dev20241105+cu121

I was running a customized model, but I also just ran the llama3_2_vision/11B_full config with

optimizer:
  _component_: torchao.prototype.low_bit_optim.AdamW8bit
  lr: 2e-5

The command I used is tune run --nproc_per_node 4 full_finetune_distributed --config llama3_2_vision/11B_full
and it gives the same error:

[rank1]:   File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 635, in _create_aot_dispatcher_function
[rank1]:     fw_metadata = run_functionalized_fw_and_collect_metadata(
[rank1]:                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 226, in inner
[rank1]:     raise RuntimeError(
[rank1]: torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
[rank1]: RuntimeError: Mutations on non-contiguous inputs are currently not allowed on tensor subclasses

@gau-nernst thanks, please let me know if you are seeing a different error.

@gau-nernst
Contributor

Using the following versions

  • torch==2.6.0.dev20241107+cu121
  • torchao==0.7.0+git2ba1a61f

This is mainly a problem of DTensor + torch.compile

One problem is that positional encodings in torchtune are not contiguous after resizing. Adding the following patch will make the previous error go away.

Patch torchtune
diff --git a/torchtune/models/clip/_position_embeddings.py b/torchtune/models/clip/_position_embeddings.py
index cd1ea594..09a98862 100644
--- a/torchtune/models/clip/_position_embeddings.py
+++ b/torchtune/models/clip/_position_embeddings.py
@@ -319,7 +319,7 @@ class TiledTokenPositionalEmbedding(nn.Module):
         # add cls token back in
         local_pos_embed = torch.cat([cls_token, local_pos_embed], dim=0)
 
-        return local_pos_embed
+        return local_pos_embed.contiguous()
 
     # TODO: Switch to public method after 2.5 is stable
     @staticmethod
@@ -436,7 +436,7 @@ class TiledTokenPositionalEmbedding(nn.Module):
         # add cls token back in
         global_pos_embed = torch.cat([cls_embed, pos_embed], dim=2)
 
-        return global_pos_embed
+        return global_pos_embed.contiguous()
 
     def forward(self, x: torch.Tensor, aspect_ratio: torch.Tensor) -> torch.Tensor:
         """
@@ -633,7 +633,7 @@ class TilePositionalEmbedding(nn.Module):
         )
         # permute to the original shape
         embedding = embedding.permute(2, 3, 0, 1)
-        return embedding
+        return embedding.contiguous()
 
     def forward(self, x: torch.Tensor, aspect_ratio: torch.Tensor) -> torch.Tensor:
         """

With the above patch, there is a new error

[rank1]:   File "/opt/conda/lib/python3.11/site-packages/torch/fx/experimental/symbolic_shapes.py", line 3813, in _create_symbolic_sizes_strides_storage_offset
[rank1]:     assert len(dynamic_sizes) == dim, f"{len(dynamic_sizes)} != {dim}"
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: AssertionError: 2 != 1

[rank1]: from user code:
[rank1]:    File "/root/torchtune/ao/torchao/prototype/low_bit_optim/adam.py", line 157, in single_param_adam
[rank1]:     p_f32 = p.float()

Again, this is related to DTensor + torch.compile(): pytorch/ao#652. For some reason, torch.compile() will try to use dynamic shapes for DTensor, even though we specifically request it not to. I will try to make a small reproducible snippet and create an issue in PyTorch core.

@acisseJZhong In the meantime, you can apply the following patch to torchao to unblock your use case. (Note that it then only works with FSDP2, which uses DTensor and has the ._local_tensor attribute.)

Patch torchao
diff --git a/torchao/prototype/low_bit_optim/adam.py b/torchao/prototype/low_bit_optim/adam.py
index 1c371897..052f469d 100644
--- a/torchao/prototype/low_bit_optim/adam.py
+++ b/torchao/prototype/low_bit_optim/adam.py
@@ -112,11 +112,11 @@ class _AdamBase(Optimizer):
                         )
 
                     torch.compile(single_param_adam, fullgraph=True, dynamic=False)(
-                        p,
-                        grad,
+                        p._local_tensor,
+                        grad._local_tensor,
                         state["step"],
-                        state["exp_avg"],
-                        state["exp_avg_sq"],
+                        state["exp_avg"]._local_tensor,
+                        state["exp_avg_sq"]._local_tensor,
                         state.get("max_exp_avg_sq", None),
                         group["lr"],
                         group["betas"][0],
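
(If you also need non-FSDP2 runs to keep working, a hypothetical defensive variant of the same idea, not part of torchao, would only unwrap when the attribute exists:)

def _unwrap_local(t):
    # Hypothetical helper: return the DTensor's local shard if present (FSDP2),
    # otherwise return the tensor unchanged (e.g. single-GPU plain tensors).
    if t is None:
        return None
    return getattr(t, "_local_tensor", t)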

The torchtune patch can be submitted as a PR to torchtune, cc @felipemello1 @ebsmothers. Not sure if I will apply the torchao patch to torchao though, because technically it is a DTensor + torch.compile problem, and the general consensus from the PyTorch optimizer team is that I shouldn't need to unwrap DTensor for the optimizer.

@gau-nernst
Contributor

@acisseJZhong Can you check whether pytorch/ao#1269 fixes your issue (instead of unwrapping DTensor with ._local_tensor)? I don't have access to a test setup at the moment, will try tomorrow.

@acisseJZhong
Contributor Author

Hey @gau-nernst, I patched the fix into my local torchao version, and it's failing with the same error :(

@gau-nernst
Contributor

Which error exactly: "Mutations on non-contiguous inputs" or "assert len(dynamic_sizes) == dim"?

Have you applied the torchtune patch? It's merged to main, so you can also try torchtune main.

For your own model, you also need to check that all params in your model are contiguous.

@acisseJZhong
Contributor Author

I applied both the torchtune and torchao fixes you linked, and I am seeing the same error as in my initial bug report:

torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
RuntimeError: Mutations on non-contiguous inputs are currently not allowed on tensor subclasses
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

I am seeing this for both the Llama3.2 Vision 11B full finetune and my own model. Were you able to run the 11B vision model with this optimizer?

@gau-nernst
Contributor

I could run Llama3.2 vision 11B with my original patch (add .contiguous() to torchtune and add ._local_tensor for torchao). Let me confirm that again.

The "Mutations on non-contiguous inputs" error seems to come from a non-contiguous tensor in torchtune though 🤔

@acisseJZhong
Contributor Author

acisseJZhong commented Nov 13, 2024

I patched the fixes in https://github.com/pytorch/torchtune/pull/1986/files that added .contiguous(), and applied the .detach() change from https://github.com/pytorch/ao/pull/1269/files to my local torchao in my conda env.

What is this _local_tensor change? I might be missing this.

@gau-nernst
Contributor

gau-nernst commented Nov 13, 2024

For the torchao patch, I originally fixed it with this #1978 (comment) (see "Patch torchao" above).

Then I came up with another approach in pytorch/ao#1269, as you have seen. You only need to apply one of the two patches.

Just remembered that I tested it with 2x A100, while you ran with 4 GPUs. It may matter because some params in the ViT (e.g. pos_embed) have a very small first dim, so I'm not sure how FSDP will shard them if first dim < num_gpus. Perhaps you can try a 2-GPU setup also?

@acisseJZhong
Contributor Author

Thanks for replying! I tried again with the Llama3.2 Vision 11B model, and the mutations on non-contiguous inputs error is fixed! Now I am seeing another error:

[rank0]: torch._dynamo.exc.TorchRuntimeError: Failed running call_method lerp(*(DTensor(local_tensor=OptimState8bit(signed=True, block_size=256, shape=(401, 1280), device=cuda:0, requires_grad=False), device_mesh=DeviceMesh('cuda', [0, 1, 2, 3]), placements=(Shard(dim=0),)), DTensor(local_tensor=FakeTensor(..., device='cuda:0', size=(401, 1280)), device_mesh=DeviceMesh('cuda', [0, 1, 2, 3]), placements=(Shard(dim=0),)), 0.09999999999999998), **{}):
[rank0]: Attempting to broadcast a dimension of length 1601 at -2! Mismatching argument at index 1 had torch.Size([1601, 1280]); but expected shape should be broadcastable to [1604, 1280]

[rank0]: from user code:
[rank0]:    File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torchao/prototype/low_bit_optim/adam.py", line 145, in single_param_adam
[rank0]:     exp_avg_f32 = exp_avg.float().lerp(grad_f32, 1 - beta1)

As for running my own model, I am still seeing the non-contiguous error. I'm wondering how you managed to identify which tensor is non-contiguous. In my case, I am not using the vision encoder, just the regular RoPE here: https://github.com/pytorch/torchtune/blob/main/torchtune/modules/position_embeddings.py

@gau-nernst
Contributor

gau-nernst commented Nov 13, 2024

I just re-ran Llama3.2 Vision 11B with the .detach() patch, and I'm getting the same error as you. However, if I use the ._local_tensor patch, it works. Investigating now. In the meantime, you can use the ._local_tensor patch to unblock your case.

Regarding non-contiguous tensors in your model, you can try doing something like this after creating the model and loading its weights:

for name, p in model.named_parameters():
    if not p.is_contiguous():
        print(name)

Then you know which params are not contiguous, and can add .contiguous() in the appropriate places. Note that if you use .load_state_dict(assign=True), you also need to make sure the params in the state dict are contiguous. This is what happens in torchtune: torchtune uses load-state-dict hooks to resize the ViT pos_embed, which becomes non-contiguous, and loads it with assign=True into the model. Hence, the patch for torchtune was to call .contiguous() in the pos_embed resize function (which modifies the state dict), not on the model.
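
(A minimal toy illustration, not torchtune code, of why state-dict contiguity matters with assign=True:)

import torch
import torch.nn as nn

m = nn.Linear(4, 4, bias=False)
sd = {"weight": torch.randn(4, 4).t()}  # transposed view -> non-contiguous

# With assign=True the module keeps the state-dict tensor itself instead of
# copying it into the existing (contiguous) parameter storage.
m.load_state_dict(sd, assign=True)
print(m.weight.is_contiguous())  # False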

@gau-nernst
Contributor

gau-nernst commented Nov 13, 2024

[rank0]: torch._dynamo.exc.TorchRuntimeError: Failed running call_method lerp(*(DTensor(local_tensor=OptimState8bit(signed=True, block_size=256, shape=(401, 1280), device=cuda:0, requires_grad=False), device_mesh=DeviceMesh('cuda', [0, 1, 2, 3]), placements=(Shard(dim=0),)), DTensor(local_tensor=FakeTensor(..., device='cuda:0', size=(401, 1280)), device_mesh=DeviceMesh('cuda', [0, 1, 2, 3]), placements=(Shard(dim=0),)), 0.09999999999999998), **{}):
[rank0]: Attempting to broadcast a dimension of length 1601 at -2! Mismatching argument at index 1 had torch.Size([1601, 1280]); but expected shape should be broadcastable to [1604, 1280]

[rank0]: from user code:
[rank0]:    File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torchao/prototype/low_bit_optim/adam.py", line 145, in single_param_adam
[rank0]:     exp_avg_f32 = exp_avg.float().lerp(grad_f32, 1 - beta1)

This error is because the 1st dim of the param, (1601, 1280), is not divisible by world_size, so there is uneven sharding. It seems torch.compile() + DTensor is not happy when this happens, and unwrapping DTensor with ._local_tensor is the only way to handle it. I will consult the distributed folks on this. Also, I discovered a bug in selecting which params get the low-bit optim states: I have to take world size into account when checking for a multiple of block_size. Hence, the low-bit optim might be incorrect when there are params with an odd 1st dim, even with the ._local_tensor patch.
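
(For concreteness, a small worked example of the shard arithmetic implied by the error above, assuming Shard(0) pads the first dim up to a multiple of world_size:)

import math

first_dim, world_size = 1601, 4
padded = math.ceil(first_dim / world_size) * world_size  # 1604
local_rows = padded // world_size                        # 401 rows per rank
print(first_dim % world_size == 0, padded, local_rows)   # False 1604 401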

If your own model does not have odd 1st dims (i.e. every 1st dim is divisible by world_size), this shouldn't matter for you. You only need to check for non-contiguous params. Meanwhile, I will try to push some fixes to torchao...

@gau-nernst
Contributor

The latest error was because I constructed DTensor wrongly when there was uneven sharding (i.e. 1st dim is not divisible by world size). I pushed a fix to pytorch/ao#1269. Can you try this latest patch?

@felipemello1
Contributor

@gau-nernst, I think we can take a look at which param has this shape, i.e. whether it is in the model, or an activation, or something like that. On our side we can possibly pad it to 1608, without having to change AdamW8bit. I remember seeing this error before in a different context.

@gau-nernst
Contributor

@felipemello1 The AdamW8bit code needs to be updated anyway; the current one is buggy (hence this error). Even if we pad it to 1608, a larger world_size would still give uneven sharding, e.g. 1608 / 16 = 100.5.

I don't have access to a machine at the moment, but I'm guessing (1601, 1280) is the positional embedding with the CLS token, i.e. num_visual_tokens + 1, e.g.

n_tokens_per_tile = patch_grid_size**2 + 1  # +1 for cls token
scale = embed_dim**-0.5
self.positional_embedding = nn.Parameter(
    scale * torch.randn((n_tokens_per_tile, embed_dim))
)

@acisseJZhong Do you mind confirming that the latest patch at pytorch/ao#1269 works for you? Once you confirm it, I will merge the torchao PR.

@acisseJZhong
Contributor Author

@gau-nernst I confirmed that Llama3.2 11B Vision now works with torchao.prototype.low_bit_optim.AdamW8bit. My own model is still hitting the non-contiguous tensor error, and I will debug it myself. Thanks for unblocking me!
