FSDP 2 low bit optim broken on pytorch nightlies #652

Closed
msaroufim opened this issue Aug 10, 2024 · 15 comments

@msaroufim
Member
To repro: python test/prototype/test_low_bit_optim.py TestFSDP2.test_fsdp2

Logs

- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
  =========================== short test summary info ============================
  FAILED test/prototype/test_low_bit_optim.py::TestFSDP2::test_fsdp2 - RuntimeError: Process 0 exited with error code 10 and exception:
  Traceback (most recent call last):
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/testing/_internal/common_distributed.py", line 664, in run_test
      getattr(self, test_name)()
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/testing/_internal/common_distributed.py", line 543, in wrapper
      fn()
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/testing/_internal/common_utils.py", line 2918, in wrapper
      method(*args, **kwargs)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/testing/_internal/common_distributed.py", line 184, in wrapper
      return func(*args, **kwargs)
    File "/pytorch/ao/test/prototype/test_low_bit_optim.py", line 239, in test_fsdp2
      self.run_subtests(
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/testing/_internal/common_fsdp.py", line 1141, in run_subtests
      return run_subtests(self, *args, **kwargs)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/testing/_internal/common_distributed.py", line 882, in run_subtests
      test_fn(*test_args, **test_kwargs, **subtest_kwargs)
    File "/pytorch/ao/test/prototype/test_low_bit_optim.py", line 284, in _test_fsdp2
      fsdp_optim.step()
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/optim/optimizer.py", line 479, in wrapper
      out = func(*args, **kwargs)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
      return func(*args, **kwargs)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torchao/prototype/low_bit_optim/adam.py", line 110, in step
      torch.compile(param_groups_adam, fullgraph=True)(param_groups)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
      return fn(*args, **kwargs)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 1238, in __call__
      return self._torchdynamo_orig_callable(
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 514, in __call__
      return _compile(
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 902, in _compile
      guarded_code = compile_inner(code, one_graph, hooks, transform)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 653, in compile_inner
      return _compile_inner(code, one_graph, hooks, transform)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_utils_internal.py", line 85, in wrapper_function
      return StrobelightCompileTimeProfiler.profile_compile_time(
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time
      return func(*args, **kwargs)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 686, in _compile_inner
      out_code = transform_code_object(code, transform)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
      transformations(instructions, code_options)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 208, in _fn
      return fn(*args, **kwargs)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 622, in transform
      tracer.run()
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 2731, in run
      super().run()
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 958, in run
      while self.step():
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 870, in step
      self.dispatch_table[inst.opcode](self, inst)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 1048, in STORE_FAST
      self._store_fast(inst.argval)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 1044, in _store_fast
      loaded_vt.set_name_hint(name)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/variables/lazy.py", line 156, in realize_and_forward
      return getattr(self.realize(), name)(*args, **kwargs)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/variables/lazy.py", line 63, in realize
      self._cache.realize()
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/variables/lazy.py", line 29, in realize
      self.vt = VariableBuilder(tx, self.source)(self.value)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/variables/builder.py", line 337, in __call__
      vt = self._wrap(value)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/variables/builder.py", line 516, in _wrap
      return self.wrap_tensor(value)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/variables/builder.py", line 1474, in wrap_tensor
      tensor_variable = wrap_fx_proxy(
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/variables/builder.py", line 1910, in wrap_fx_proxy
      return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/variables/builder.py", line 2022, in wrap_fx_proxy_cls
      example_value = wrap_to_fake_tensor_and_record(
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/variables/builder.py", line 2589, in wrap_to_fake_tensor_and_record
      fake_e = wrap_fake_exception(
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 1459, in wrap_fake_exception
      return fn()
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/variables/builder.py", line 2590, in <lambda>
      lambda: tx.fake_mode.from_tensor(
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 2171, in from_tensor
      return self.fake_tensor_converter.from_real_tensor(
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 374, in from_real_tensor
      out = self.meta_converter(
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_subclasses/meta_utils.py", line 1642, in __call__
      r = self.meta_tensor(
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_subclasses/meta_utils.py", line 1543, in meta_tensor
      r.grad = self.meta_tensor(
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_subclasses/meta_utils.py", line 1434, in meta_tensor
      r = empty_create_subclass(
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_subclasses/meta_utils.py", line 847, in empty_create_subclass
      sub = _empty_create_subclass(
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_subclasses/meta_utils.py", line 833, in _empty_create_subclass
      new_empty_tensor = _empty_create_subclass(
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_subclasses/meta_utils.py", line 818, in _empty_create_subclass
      return self.meta_tensor(
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_subclasses/meta_utils.py", line 1322, in meta_tensor
      base = self.meta_tensor(
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_subclasses/meta_utils.py", line 1429, in meta_tensor
      ) = sym_sizes_strides_storage_offset(t, source, symbolic_context)
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/_subclasses/meta_utils.py", line 752, in sym_sizes_strides_storage_offset
      return shape_env._create_symbolic_sizes_strides_storage_offset(
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/fx/experimental/recording.py", line 262, in wrapper
      return retlog(fn(*args, **kwargs))
    File "/opt/conda/envs/venv/lib/python3.9/site-packages/torch/fx/experimental/symbolic_shapes.py", line 3222, in _create_symbolic_sizes_strides_storage_offset
      assert len(dynamic_sizes) == dim, f"{len(dynamic_sizes)} != {dim}"
  AssertionError: 2 != 1
  
  from user code:
     File "/opt/conda/envs/venv/lib/python3.9/site-packages/torchao/prototype/low_bit_optim/adam.py", line 116, in param_groups_adam
      for p, grad, step, exp_avg, exp_avg_sq, max_exp_avg_sq in group:
  
  Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
  
  
  You can suppress this exception and fall back to eager by setting:
      import torch._dynamo
      torch._dynamo.config.suppress_errors = True
  
  
  To execute this test, run the following from the base repo dir:
      python test/prototype/test_low_bit_optim.py TestFSDP2.test_fsdp2
  
  This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
  
  Process 1 exited with error code 10 and exception: same traceback as Process 0, ending in AssertionError: 2 != 1.
  ==== 1 failed, 1135 passed, 267 skipped, 55 warnings in 1560.72s (0:26:00) =====
  Traceback (most recent call last):
    File "/home/ec2-user/actions-runner/_work/ao/ao/test-infra/.github/scripts/run_with_env_secrets.py", line 102, in <module>
      main()
    File "/home/ec2-user/actions-runner/_work/ao/ao/test-infra/.github/scripts/run_with_env_secrets.py", line 98, in main
      run_cmd_or_die(f"docker exec -t {container_name} /exec")
    File "/home/ec2-user/actions-runner/_work/ao/ao/test-infra/.github/scripts/run_with_env_secrets.py", line 39, in run_cmd_or_die
      raise RuntimeError(f"Command {cmd} failed with exit code {exit_code}")
  RuntimeError: Command docker exec -t 6a284e024d9e5fa50319dc78d9852c462a2c247de64ee9d0a3d6d326ab401309 /exec failed with exit code 1
  Error: Process completed with exit code 1.
@gau-nernst
Collaborator

gau-nernst commented Aug 11, 2024

Just realized that TORCH_VERSION_AFTER_2_4 returns False on 2.4.0, so we still have that old problem 🤣. This means the low-bit optim FSDP2 test does not run in the 2.4.0 CI:

@pytest.mark.skipif(not TORCH_VERSION_AFTER_2_4, reason="torch >= 2.4 required")

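For reference, a version gate that treats 2.4.0 itself as satisfying ">= 2.4" could look like the sketch below (my sketch, assuming the packaging library is available; torch_version_at_least is a hypothetical name, not torchao's actual helper):

import torch
from packaging.version import Version

def torch_version_at_least(min_version: str) -> bool:
    # Hypothetical helper, not torchao's actual utility. PEP 440 comparison:
    # "2.4.0" >= "2.4.0" is True, and a nightly such as "2.5.0.dev20240810"
    # parses as a pre-release of 2.5.0, so it also passes a ">= 2.4.0" check.
    return Version(torch.__version__) >= Version(min_version)

# e.g. @pytest.mark.skipif(not torch_version_at_least("2.4.0"), reason="torch >= 2.4 required")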

@gau-nernst
Collaborator

The error message is very cryptic. AdamW8bit doesn't use dynamic shapes though, so I don't know why this pops up. And this error only happens in the FSDP2 test, not in the normal single-GPU test. Would you know who could take a look at this?

@msaroufim
Member Author

Yeah, the version problem is getting out of hand; I'll fix that ASAP.

Regarding the error, @awgu and @weifengpy are usually my go-tos for FSDP2 issues.

@awgu
Contributor

awgu commented Aug 11, 2024

I think low-bit optimizer + FSDP2 is really low-bit optimizer + DTensor + torch.compile, for which @bdhirsh is probably the best person to ask.

@bdhirsh
Contributor

bdhirsh commented Aug 12, 2024

(taking a look)

@bdhirsh
Contributor

bdhirsh commented Aug 12, 2024

The problem is that we have a pretty complicated input to the compiled region: the input is a DTensor whose _local_tensor has a ._base, and which also has a populated .grad field that is itself a DTensor, whose _local_tensor._base has a different number of dims than the original ._base.

I have a minimal repro here: pytorch/pytorch#133274.
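
For intuition, the shape mismatch can be mimicked with plain tensors (my sketch, not the repro from the linked issue; the real case additionally wraps everything in DTensor):

import torch

# A tensor that is a view of a 2-D base...
base_2d = torch.randn(2, 8)
p = base_2d[:1]                                # p._base.ndim == 2

# ...whose .grad is a view of a 1-D base (like FSDP2's flat gradient buffer).
flat_grad = torch.randn(16)
p.grad = flat_grad.as_strided((1, 8), (8, 1))  # p.grad._base.ndim == 1

print(p._base.ndim, p.grad._base.ndim)         # prints: 2 1

Fakeifying p also fakeifies p.grad and then both bases, which matches the traceback above, where assert len(dynamic_sizes) == dim fires with 2 != 1 while handling the .grad base.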

In the meantime, I also found that this tweak gets me past the error, although I'm not sure we actually want to land it in eager FSDP2 (cc @awgu):

diff --git a/torch/distributed/_composable/fsdp/_fsdp_collectives.py b/torch/distributed/_composable/fsdp/_fsdp_collectives.py
index d739ffbcf96..c512ea7c37f 100644
--- a/torch/distributed/_composable/fsdp/_fsdp_collectives.py
+++ b/torch/distributed/_composable/fsdp/_fsdp_collectives.py
@@ -324,7 +324,7 @@ def foreach_reduce(
                 size=fsdp_param.sharded_size,
                 stride=fsdp_param.contiguous_sharded_stride,
                 storage_offset=flat_grad_offset,
-            )
+            ).detach()
             to_accumulate_grad = fsdp_param.sharded_param.grad is not None
             if fsdp_param.offload_to_cpu:
                 # Only overlap the D2H copy (copying to pinned memory) if not
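
(The .detach() makes the new sharded grad a non-view tensor: it still shares storage with the flat gradient buffer but no longer has a ._base, which is presumably why fakeification no longer sees mismatched base dims. That's my reading of why it sidesteps the error, not a confirmed explanation.)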

@gau-nernst
Collaborator

Thank you for the quick debugging. May I ask about this:

the input is a DTensor whose _local_tensor has a ._base, and which also has a populated .grad field that is itself a DTensor, whose _local_tensor._base has a different number of dims than the original ._base

You mentioned that the input in question has a .grad field, which indicates it is a parameter. In the low-bit optim test, only the optimizer states are tensor subclasses, so they shouldn't have a .grad field. I think something is not quite right here?

@bdhirsh
Contributor

bdhirsh commented Aug 13, 2024

I was looking at the values of param_groups, which are the inputs to your torch.compile region, here: https://github.com/pytorch/ao/blob/main/torchao/prototype/low_bit_optim/adam.py#L110

And empirically, param_groups contains DTensor parameters with the above properties. Are you saying you don't expect the parameters themselves to be DTensors? Maybe @awgu would know better?

(Pdb) p type(param_groups[0][0][0][0])
<class 'torch.distributed._tensor.api.DTensor'>
(Pdb) p param_groups[0][0][0][0]._local_tensor._base.ndim
2
(Pdb) p param_groups[0][0][0][0].grad._local_tensor._base.ndim
1
(Pdb) p isinstance(param_groups[0][0][0][0], torch.nn.Parameter)
True

@gau-nernst
Collaborator

@bdhirsh I see, thank you for the clarification. The subclass you were referring to is DTensor, not my custom subclass for the quantized optimizer state. That makes sense.

But it raises another question: how come the other FSDP2 tests in torchao did not fail 😅? Then I remembered NF4 is not trainable, so it won't have a .grad field. Not sure about the other FSDP2 tests in torchao.

In the end, is it correct to say that the bug is really about FSDP2 + torch.compile(optim_step)? If it is not isolated to custom optimizers, perhaps we can add some tests for this scenario in PyTorch core or other repos too.

@bdhirsh
Contributor

bdhirsh commented Aug 13, 2024

In the end, is it correct to say that the bug is really about FSDP2 + torch.compile(optim_step)? If it is not isolated to custom optimizers, perhaps we can add some tests for this scenario in PyTorch core or other repos too.

Yeah, I could definitely believe that this is true. (I don't have the bandwidth to add those tests, but if someone wants to try making a smaller repro that doesn't use your low-bit optimizer, they are welcome to 😄)

Then again, I think this is a pretty one-off bug that we expected to be hit very rarely (we haven't had to exercise much code in compile where tensor inputs to the graph also have .grad fields that are subclasses), and it should have a relatively straightforward fix.

@gau-nernst
Collaborator

gau-nernst commented Aug 13, 2024

I can confirm that normal optimizers have this bug too:

import torch
from torch.distributed._composable.fsdp import fully_shard
from torch.testing._internal.distributed._tensor.common_dtensor import ModelArgs, Transformer

batch_size = 3
vocab_size = 1024
seq_len = 64
model_args = ModelArgs(
    n_layers=3,
    n_heads=4,
    dim=1024,
    vocab_size=vocab_size,
    max_seq_len=seq_len,
)
model = Transformer(model_args).cuda()

# shard each transformer block, then the root module
for m in model.layers:
    fully_shard(m)
fully_shard(model)
optim = torch.optim.AdamW(model.parameters(), lr=1e-2, foreach=False, fused=False)

# compile the optimizer step
optim.step = torch.compile(optim.step)

for iter_idx in range(5):
    inp = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda")
    model(inp).mean().backward()
    optim.step()
    optim.zero_grad()

Run with torchrun --nnodes 1 --nproc_per_node 1 debug.py

Agree with your last point 😄! Hopefully the fix in PyTorch core is coming soon! Thank you for the help!

@gau-nernst
Collaborator

@bdhirsh I noticed that the FSDP test for low-bit optim now passes with torch nightly. Was it fixed in core recently? I didn't see any updates in pytorch/pytorch#133274.

@bdhirsh
Contributor

bdhirsh commented Sep 5, 2024

Hmm, that's strange. I ran the non-subclass repro you posted above locally and it still fails for me.

@gau-nernst
Collaborator

Hmm, I think it might be because I changed the way I compile the optim step. Now I compile a static-shape optim step for each param, instead of one optim step over all params (#812). In that case the issue in PyTorch core is still there, but we can probably close this issue?
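
Roughly, the per-param approach looks like the sketch below (a simplified illustration, not the exact code from #812): one static-shape graph per parameter update, driven by a plain eager loop.

import torch

@torch.compile(fullgraph=True, dynamic=False)
def single_param_adam(p, grad, step, exp_avg, exp_avg_sq, lr, beta1, beta2, eps):
    # Static-shape Adam update for one parameter; each distinct param shape gets
    # its own specialized graph instead of one dynamic graph over all params.
    # step is a 0-dim tensor so incrementing it doesn't trigger recompiles.
    exp_avg.lerp_(grad, 1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step
    denom = (exp_avg_sq / bias_correction2).sqrt().add_(eps)
    p.sub_(lr / bias_correction1 * exp_avg / denom)

The loop over params stays in eager Python, and with dynamic=False no symbolic shapes are created, which is presumably why the assert above is never reached.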

@bdhirsh
Contributor

bdhirsh commented Sep 5, 2024

Ah yeah, great. This is definitely just a bug at the intersection of subclasses + dynamic shapes + optimizer/gradients, so if you're OK with static shapes only for now (which might be better for perf anyway), closing this issue sounds fine to me.
