Skip to content

Initial FSDP2 support#3394

Merged
muellerzr merged 89 commits intohuggingface:mainfrom
S1ro1:dev/fsdp2
Mar 27, 2025
Merged

Initial FSDP2 support#3394
muellerzr merged 89 commits intohuggingface:mainfrom
S1ro1:dev/fsdp2

Conversation

@S1ro1
Copy link
Contributor

@S1ro1 S1ro1 commented Feb 11, 2025

FSDP2

What does this add?

This PR adds support for enabling FSDP2 in Accelerate as another way to parallelize your training 🚀

Why is it needed?

FSDP2 offers a lot of advantages over current FSDP1 implementation, mostly not user facing, but this enables further development of features in Accelerate such as combining TP + FSDP.
For users, this enables small improvements in memory usage

What parts of the API does this impact?

User-facing:

Allows users to specify --fsdp-version=2 to use FSDP2 instead of current FSDP1, this also changes some of the configuration options for FullyShardedDataParallelPlugin. Old configurations are still supported as before.

Basic Usage Example(s):

from accelerate import FullyShardedDataParallelPlugin, Accelerator

fsdp_plugin = FullyShardedDataParallelPlugin(
    fsdp_version=2
    # other options...
)
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)

When would I use it, and when wouldn't I?

Using FSDP2 as a default is reccomended, with swapping to FSDP1 if something isn't supported.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Contributor

@muellerzr muellerzr left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great start! Left some initial comments

self.reshard_after_forward = ShardingStrategy[self.reshard_after_forward.upper()]
if self.fsdp_version != 2 and isinstance(self.reshard_after_forward, bool):
raise ValueError(
"reshard_after_forward set to bool. This is not supported in FSDP1, please set to a `str` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy`"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"reshard_after_forward set to bool. This is not supported in FSDP1, please set to a `str` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy`"
f"reshard_after_forward set to {self.reshard_after_forward}. This is not supported with FSDP1, please set to a `str` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy`"

Comment on lines +1653 to +1660
if self.fsdp_version != 2 and isinstance(self.cpu_offload, CPUOffloadPolicy):
raise ValueError(
"cpu_offload set to `torch.distributed.fsdp.CPUOffloadPolicy`. This is not supported in FSDP1, please set to a `bool` or an instance of `torch.distributed.fsdp.CPUOffload`"
)
if self.fsdp_version == 2 and not isinstance(self.cpu_offload, CPUOffloadPolicy):
raise ValueError(
"cpu_offload set to `bool` or `torch.distributed.fsdp.CPUOffload`. This is not supported in FSDP2, please set to an instance of `torch.distributed.fsdp.CPUOffloadPolicy`"
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like we can simplify this

Suggested change
if self.fsdp_version != 2 and isinstance(self.cpu_offload, CPUOffloadPolicy):
raise ValueError(
"cpu_offload set to `torch.distributed.fsdp.CPUOffloadPolicy`. This is not supported in FSDP1, please set to a `bool` or an instance of `torch.distributed.fsdp.CPUOffload`"
)
if self.fsdp_version == 2 and not isinstance(self.cpu_offload, CPUOffloadPolicy):
raise ValueError(
"cpu_offload set to `bool` or `torch.distributed.fsdp.CPUOffload`. This is not supported in FSDP2, please set to an instance of `torch.distributed.fsdp.CPUOffloadPolicy`"
)
if isinstance(self.cpu_offload, CPUOffloadPolicy):
err = "`cpu_offload` set to `torch.distributed.fsdp.CPUOffloadPolicy`."
if self.fsdp_version != 2:
raise ValueError(f"{err} This is not supported in FSDP1, please set to a `bool` or an instance of `torch.distributed.fsdp.CPUOffload`")
else:
raise ValueError(f"{err} This is not supported in FSDP2, please set to an instance of `torch.distributed.fsdp.CPUOffloadPolicy`")

self.cpu_offload = CPUOffload(offload_params=self.cpu_offload)
if self.fsdp_version == 2:
if not self.cpu_offload:
warnings.warn("Offload_params is set to False, however FSDP2 always offloads parameters and runs optimizer step on CPU. This will be overridden to True.")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that in general, users will ignore warnings (and will get annoyed it bloats the logs). So we should instead use logger.warn if we don't want to explicitly raise an error about this.

if self.use_orig_params is None:
self.use_orig_params = str_to_bool(os.environ.get(env_prefix + "USE_ORIG_PARAMS", "False")) == 1
if self.fsdp_version == 2 and self.use_orig_params is not None:
warnings.warn("use_orig_params is obsolete in FSDP2, as FSDP2 always uses the original parameters.")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of doing many warning.warn,s let's accumulate them all and do one big logger.warn at the end

Comment on lines +1719 to +1722
if self.fsdp_version == 2 and self.forward_prefetch is not None:
raise ValueError(
"forward_prefetch is not yet implemented in FSDP2, set to None or use `fsdp_version` set to 1"
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if self.fsdp_version == 2 and self.forward_prefetch is not None:
raise ValueError(
"forward_prefetch is not yet implemented in FSDP2, set to None or use `fsdp_version` set to 1"
)
if self.fsdp_version == 2 and self.forward_prefetch is not None:
raise ValueError(
"forward_prefetch is not yet implemented in FSDP2, set to None or use `fsdp_version=1`"
)

"""
Validates the mixed precision policy, abstracted away to not bring in the imports if not needed.
"""
from torch.distributed.fsdp import MixedPrecision, MixedPrecisionPolicy
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this import lead to issues on old pytorch versions?

@kmehant
Copy link
Contributor

kmehant commented Feb 13, 2025

@muellerzr @S1ro1

I see potential collaboration on this thread with my PRs

  1. Support TP + FSDPv2 / HSDP or just FSDPv2 / HSDP #3395
  2. [RFC] Support FSDP2 #3231

With changes on this PR I understand that the design is to convert FSDP1 args to FSDP2 args and put to use. With my initial discussions started at #3231, In the PR #3395 I propose to have FSDP2 as a separate distributed type training protocol alongsides FSDP1. Also, to support more complex combinations of parallelisms I as well propose a design as described in the PR's description.

Looking forward.

Copy link
Member

@stevhliu stevhliu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good breakdown of their differences, and the table at the end really helps with the comparison as well! 👏 My only feedback is:

The main feature of FSDP2 is DTensor, but I didn't really realize it's significance until the following section that illustrated the limitation of FSDP1. Now that first sentence "Simpler internal implementation, where each Parameter is a separate DTensor" is more impactful to me.

I wonder if it'd be better to start with that section so users understand what DTensor is and how it enables FSDP2 to overcome the limitations of FSDP1. Then you can discuss the other benefits of FSDP2 in the following section.

Each Parameter of the original `Layer` is sharded across the 0th dimension, and split between 2 GPUs. Now, each `Linear` layer is a separate `DTensor` and storing metadata per-parameter is possible and straightforward.


> [!NOTE]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The doc-builder only supports TIP or WARNING

Suggested change
> [!NOTE]
> [!TIP]

@SunMarc
Copy link
Member

SunMarc commented Mar 26, 2025

cc @a-r-r-o-w

@S1ro1 S1ro1 changed the title WIP: Initial FSDP2 support Initial FSDP2 support Mar 26, 2025
@S1ro1 S1ro1 marked this pull request as ready for review March 26, 2025 17:04
Copy link
Member

@SunMarc SunMarc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for this huge work ! Left a couple of comments

### If using YAML config:
Use our conversion tool:
```bash
accelerate to-fsdp2 --config_file config.yaml --output_file new_config.yaml
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea but maintaining that might be a bit of a pain. You already have added the deprecation msg for some args + not sure what will happen with this tool you will fully deprecate fsdp_sharding_strategy. wdyt @muellerzr ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For now we should maintain until PyTorch is fully removed with the old or accelerate 2.0 IMO. It's fully breaking on the pytorch side and that hasn't sat well with me

@muellerzr muellerzr merged commit d7c741a into huggingface:main Mar 27, 2025
25 checks passed
Comment on lines +1402 to +1406
if (model_count < 1 and optimizer_count > 0) or (model_count > 0 and optimizer_count < 1):
raise ValueError(
"When using FSDP2, a model and optimizer must be passed together to `Accelerator.prepare()`"
" as the optimizer needs to have its parameters modified after the model is converted."
)
Copy link
Contributor

@kmehant kmehant Mar 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for this work! Couple of comments/questions.

Wouldn't this approach conflict with the way transformers does it with delay_optimizer_creation (https://github.com/huggingface/transformers/blob/348f3285c5114159d2ff4933b4b8ae36866d01a7/src/transformers/trainer.py#L2314) knob?

Use of transformers would need dropping use of this knob for fsdpv2.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, exactly, there would be no need for this knob for FSDP2 anymore. I don't think that it is a breaking change, rather an improvement.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 on this. I'm guessing transformers will require changes to get this working with Trainer.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Absolutely, if you do create a PR feel free to mention me, if not I'll take a look next week ish.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I actually do need to get Trainer to work with FSDP2, so if I have time I might take a look over the weekend if I have time. I'm not sure how much effort getting everything together would be.

Hopefully it can literally be a 1 line change in transformers, but I ran into some OOM issue when I tried hacking it together, so maybe it won't be that easy.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@byi8220 I haven't done any research/work on using LoRA yet. This PR was just an initial work that supports simple FSDP2, I'll try to check how it interacts with LORA in the near future. Thanks for the checks on that. I think that preliminary merging your PR for transformers integration is fine, as it should support all that is currently supported by Accelerate, so I approved for now.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SG. I'll try to find a second GPU and run the tests before I officially raise the PR.

At worst this shouldn't break anything, since this is a no-op unless the user explicitly opts in.

I'll try to check how it interacts with LORA in the near future. Thanks for the checks on that.

I've mostly been hacking around, and I'm pretty unfamiliar with FSDP, so there is a high chance I am misinterpreting, or that I am not setting up my trainer properly. However everything I am suspecting as memory wastage seems logical to me (need to wrap separately, need to be clever with mmap). I might file an issue if I can make a clean reproduction

Copy link
Contributor Author

@S1ro1 S1ro1 Apr 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd suggest trying to do LORA with accelerate only first, do trainer after. As of the mmap stuff: getting that correctly has been a hassle in my limited experience with this bug/feature. Again, would leave it out of this and try to make simplest reproduction. If you won't come up with anything I'll try to look at this ~next week.

My guess is maybe we have to wrap the lora modules in isolation, and replicate what lora_fsdp_wrap_policy is doing?

I'd assume this is not the case with FSDP2, FSDP2 also known as per-parameter sharding allows each parameter to have its own metadata and allocates for those (I'm 99% sure) - you can check more about this here. Therefore this is not needed as sharding "only" influences the distributed semantics. I'm pretty certain this is needed only for FSDP1. For FSDP2 we can just shard as we would usually.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll try to find a second GPU and run the tests before I officially raise the PR.

PR was raised, passed 1/2 new tests. The other test had to be skipped (its FSDP 1 counterpart is skipped)

I'd suggest trying to do LORA with accelerate only first, do trainer after.

SG, I'll try to get something small and raise a bug for it soon. Getting fsdp2+lora in trainer working is a high priority blocker for me, so I will look into this tomorrow

For FSDP2 we can just shard as we would usually.

Makes sense, I experimented with this a bit and found no savings.

As of the mmap stuff: getting that correctly has been a hassle in my limited experience with this bug/feature.

I can relate...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Filed a bug about this, #3474

FSDP_PYTORCH_VERSION = (
"2.1.0.a0+32f93b1" # Technically should be 2.1.0, but MS-AMP uses this specific prerelease in their Docker image.
)
FSDP2_PYTORCH_VERSION = "2.5.1"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we make this 2.6.0?

Since imports such as

from torch.distributed.fsdp import CPUOffloadPolicy, OffloadPolicy
would break for any version below 2.6.0

Comment on lines +396 to +397
if not is_torch_version(">=", FSDP2_PYTORCH_VERSION):
raise ImportError(f"FSDP2 requires PyTorch >= {FSDP2_PYTORCH_VERSION}")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for this check, you might also want to protect imports since the imports (such as

from torch.distributed.fsdp import CPUOffloadPolicy, OffloadPolicy
) seem to break first even before we hit this condition.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the sanity checks. Will do more thorough version checks in the follow-up PRs.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @S1ro1, appreciate your review to PR #3499 since I needed these on my setup. Thanks.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

10 participants