Initial FSDP2 support #3394
Merged
Changes from 85 commits (89 commits in total)
All commits authored by S1ro1:

- `2cfe2db` Feat: initial conversion tool draft
- `d055eab` Feat: add value mapping to conversion tool
- `23ea666` Refactor: move from os to pathlib
- `7deaaaa` Feat: add first tests
- `d506b95` Feat: more tests
- `3eab226` Feat: minor fixes + dataclass conversions
- `269c217` Feat: more remapping
- `ba5372c` Fix: namespace has no attribute version + style
- `b16131b` Fix: offload params behavior
- `b0cc66b` Feat: add option to only rename keys in the config file to
- `8bc5cb3` Fix: wrong attr name
- `c31cc55` Merge branch 'main' into dev/fsdp2
- `bc70ec2` Fix: partially resolve comments
- `00dafc4` Feat: work on config command + minor fixes to reflect changes
- `7a92dac` Refactor: style + quality
- `b724f9c` Feat: fsdp2 initial work
- `7cd2587` Feat: some cleanups and first running fsdp2
- `d920e94` Fix: version checks + mixed precision policy
- `3cc7c20` Refactor: style + quality
- `432b4ff` Remove obsolete todos
- `bb83985` Feat: grad norm clipping
- `7759d79` Fix: tests + rename attrs
- `f8bfa04` Refactor: style + quality
- `47a12b9` Fix: None object is not iterable
- `030fa5d` Fix: default cpu_offload for fsdp2
- `c5526aa` Fix: cpu offload now behaves correctly
- `4984c8a` Feat: apply_activation_checkpointing
- `914e55f` Fix: append to models
- `1304284` Feat: start on concept guide
- `f9061f6` Merge branch 'main' into dev/fsdp2
- `c596c12` wip: concept guide
- `d55bab7` Fix: toctree
- `8494f7b` cleanup of the concept guide
- `f298b5e` Fix: minor fixes + mp
- `4e4b664` Fix: quality + | to union
- `30961c9` Feat: backwards compatibility + args cleanup
- `314df77` Fix: style + quality
- `5238cb2` Feat: enable dropping refs when getting named params
- `e104c10` Merge branch 'main' into dev/fsdp2
- `8347b91` Fix: memory footprint with fsdp2
- `70f4882` Feat: cpu ram efficient loading
- `c597598` Fix: mp
- `fa73295` Fix: not warn about sync_modules if fsdp version is 1
- `ee1267a` Refactor: minor changes
- `021e153` Small fixes + refactors
- `04a5ffb` Feat: docs + cleanup
- `d9fba39` Feat: saving works (not sure about optim)
- `0696c96` More loading/saving work
- `0d0abde` Feat: disable local_state_dict for fsdp2
- `1fe97f6` Fix: fsdp2 convergence
- `ab4e58d` Merge branch 'main' into dev/fsdp2
- `d6edd8c` Feat: working comparison script
- `8a48038` Feat: memory tracking fsdp2
- `a563493` Feat: memory visualizer
- `a8d127b` Feat: more work on benchmark
- `5560555` Fix: raise error if model+optimizer arent prepared together
- `ab5a6cf` Minor fixes
- `31ce351` Style
- `b619589` More warnings
- `4dab060` Fix: reshard_after_forward vs sharding_strategy conflict
- `c7ef53d` Refactor: clean up accelerator
- `8fadfbf` Feat: more testing in fsdp2 benchmark
- `79d5cf7` Fix: memory visualizer
- `d633c12` Untested: support load/save_state
- `8a4de4d` Feat: concept guide improvements
- `4881bbf` Refactor: concept guide
- `9fd8fd5` Feat: benchmark works
- `0d53cad` Feat: more work on fsdp2 benchmark
- `9251dcd` Fix: note syntax
- `25f15c7` Fix: small fixes + make original tests work
- `4dfb12c` Fix: grad scaling
- `d0244e2` Feat: reshard after forward tests
- `d2891a2` Feat: backward prefetch tests
- `f99e980` Feat: tests for fsdp2
- `5619fea` Refactor: minor fixes
- `3fba7e2` Feat: fsdp_utils docstrings
- `6df6483` Feat: autodoc fsdp.md
- `d24aa9d` Docs: get_module_children_bottom_up
- `050488d` Fix: remove unused images
- `e984571` Refactor: benchmark cleanup
- `ccf3d42` Fix: docs
- `124a56f` Feat: final doc changes
- `5bb4f30` Fix: torch.distributed has no attribute tensor
- `0f71c8c` Fix: style
- `0a2791b` Feat: tests include version in failures
- `9f37abd` Fix: benchmark force model to load in fp32
- `a7333d1` Fix: rename runs
- `4d4f829` Feat: last minor fixes
- `7ad58a1` Feat: new benchmark images
New file (+74 lines): the benchmark README.

# FSDP2 Benchmarks

This benchmark showcases `FSDP2` in 🤗 `accelerate` and compares it against a raw `torch` baseline.

## Overview

This benchmark consists of two parts:
- `main.py` is the main script that runs the benchmark
- `visualize.py` is the script that visualizes the results

## Motivation

We want to showcase that 🤗 `accelerate`'s integration of `FSDP2` is on par with raw PyTorch, and to highlight a "broken" part of PyTorch: creating an optimizer before applying `FSDP2` **does not result in a working training loop** (more on this later).
This script showcases **matching memory usage and convergence between `accelerate` and the `torch` baseline.**
To deal with this breaking change (and keep the API backward compatible with FSDP1), `accelerate` had to come up with a workaround, since `accelerate` assumes that the user will nearly always create the model, optimizer, scheduler, etc. beforehand and bring them themselves. This led to a stark increase in memory, and to the model not training at all if the user created an optimizer beforehand.
To work around this, we replace the parameters inside the optimizer with the newly created FSDP2-sharded ones. More about this can be found in this [blog post (TBD)](TODO)

> [!WARNING]
> This script is intended to fit on 2x 24GB GPUs, though with so few GPUs it is only possible to see the difference in convergence, not the memory difference (discrepancies in grad allocation result in lower memory usage in the non-fixed case). Results from 8x H100 GPUs, where the difference is visible, are attached below.

> TLDR: more GPUs = bigger memory difference between fixed and non-fixed cases.
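The fix described above, repointing the optimizer's parameter references at the sharded parameters, can be sketched in a few lines. This is a pure-Python illustration with stand-in objects: `ToyOptimizer`, `swap_optimizer_params`, and the name-based matching are assumptions made for the sketch, not `accelerate`'s actual implementation (which operates on `torch.optim.Optimizer` and FSDP2's `DTensor` parameters).

```python
# Illustrative sketch only: ToyOptimizer mimics the `param_groups` layout of
# torch.optim.Optimizer; plain objects stand in for parameters / sharded DTensors.
class ToyOptimizer:
    def __init__(self, params):
        self.param_groups = [{"params": list(params)}]


def swap_optimizer_params(optimizer, old_named_params, new_named_params):
    """Repoint every param group at the new (sharded) parameters,
    matching old and new parameters by name."""
    mapping = {id(old): new_named_params[name] for name, old in old_named_params.items()}
    for group in optimizer.param_groups:
        group["params"] = [mapping.get(id(p), p) for p in group["params"]]


# Before sharding: the optimizer holds references to the original parameters.
old = {"weight": object(), "bias": object()}
opt = ToyOptimizer(old.values())

# After "sharding": new parameter objects exist under the same names.
new = {"weight": object(), "bias": object()}
swap_optimizer_params(opt, old, new)

assert opt.param_groups[0]["params"] == [new["weight"], new["bias"]]
```

Without this swap, the optimizer keeps updating the old, unsharded tensors, which is exactly the "not even training" failure mode the benchmark demonstrates.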
## Results

Here are the results from running the benchmark on 8x H100 GPUs:

<p align="center">
  <img src="imgs/allocated_memory.png" width="80%" alt="Allocated Memory Usage">
</p>
<p align="center">
  <img src="imgs/reserved_memory.png" width="80%" alt="Reserved Memory Usage">
</p>

As you can see, the memory usage of `accelerate` and `torch_post_shard` (the **intended** way) is very similar, while `torch_pre_shard_not_fixed` uses significantly more memory. Our fix in `torch_pre_shard_fixed` brings the memory usage back in line with the **intended** approach.

> [!WARNING]
> Timing discrepancies are due to all benchmarks being run in a single script.

## Running

To run the benchmark, you can use either `accelerate launch` or `torchrun` (with the appropriate parameters):
```bash
accelerate launch main.py
```
```bash
# For two GPUs
torchrun --nproc_per_node 2 main.py
```
The script supports multiple configurable options; you can learn about them by running:
```bash
python3 main.py --help
```

This script will run 4 different benchmarks:
- `torch_post_shard`: `torch` baseline where the optimizer is created after applying `FSDP2`; this is the **intended** way to do it
- `torch_pre_shard_not_fixed`: `torch` baseline where the optimizer is created before applying `FSDP2`
- `torch_pre_shard_fixed`: `torch` baseline where the optimizer is created before applying `FSDP2`, but we apply our fix to the optimizer
- `accelerate`: `accelerate`'s own integration of `FSDP2`, where the optimizer is created before applying `FSDP2`, but we apply our fix to the optimizer
Memory results are saved in the folder specified by the `--output_dir` argument.
Optionally, you can pass `--save_memory_snapshot` to save a torch memory snapshot, which can then be viewed with [`torch memory viz`](https://pytorch.org/memory_viz)

## Visualizing results

To visualize the results, you can run:

```bash
python3 visualize.py --dir <path_to_output_dir>
```

This will create two plots, showing the allocated and reserved memory usage across all the benchmarks discussed above.
New file (+121 lines): the benchmark's main script (`main.py`, per the README).

```python
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
from typing import Callable

import torch

from accelerate import Accelerator
from utils import parse_args, prepare_accelerate, prepare_torch


MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
LEARNING_RATE = 3e-5

CONFIG = {
    "model_name": MODEL_NAME,
    "learning_rate": LEARNING_RATE,
}


def train(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    train_dataloader: torch.utils.data.DataLoader,
    accelerator: Accelerator,
) -> torch.Tensor:
    losses = []
    for batch in train_dataloader:
        optimizer.zero_grad()
        outputs = model(**batch, use_cache=False)

        loss = outputs.loss
        losses.append(loss.item())
        accelerator.backward(loss)
        optimizer.step()

    return torch.tensor(losses)


def evaluate(args, config: dict, init_fn: Callable, run_name: str) -> torch.Tensor:
    model, optimizer, dataloader, accelerator, memory_tracker = init_fn(args, config)

    loss = train(model, optimizer, dataloader, accelerator)

    memory_tracker.stop()
    msg = f"""Results for {run_name} (rank 0):
Loss: {loss[-1].item()}
Peak Allocated Memory: {float(memory_tracker.peak_allocated_memory):.2f} MB
Peak Reserved Memory: {float(memory_tracker.peak_reserved_memory):.2f} MB
{'-' * 34}"""
    accelerator.print(msg)
    return loss


def main():
    args = parse_args()
    evaluations = [
        functools.partial(
            evaluate,
            init_fn=functools.partial(prepare_torch, post_shard_optimizer=False, apply_optimizer_fix=True),
            run_name="Optimizer Pre fully_shard (w/ fix)",
        ),
        functools.partial(
            evaluate,
            init_fn=functools.partial(prepare_torch, post_shard_optimizer=False, apply_optimizer_fix=False),
            run_name="Optimizer Pre fully_shard (w/o fix)",
        ),
        functools.partial(
            evaluate,
            init_fn=functools.partial(prepare_torch, post_shard_optimizer=True),
            run_name="Optimizer Post fully_shard",
        ),
        functools.partial(evaluate, init_fn=prepare_accelerate, run_name="Accelerate"),
    ]
    labels = [
        "Optimizer Pre fully_shard (w/ fix)",
        "Optimizer Pre fully_shard (w/o fix)",
        "Optimizer Post fully_shard",
        "Accelerate",
    ]

    results = {}

    for evaluation, label in zip(evaluations, labels):
        results[label] = evaluation(args, CONFIG)

    torch.testing.assert_close(
        results["Optimizer Post fully_shard"],
        results["Optimizer Pre fully_shard (w/ fix)"],
        msg="Optimizer Post fully_shard and Optimizer Pre fully_shard (w/ fix) should be the same",
    )

    torch.testing.assert_close(
        results["Optimizer Post fully_shard"],
        results["Accelerate"],
        msg="Optimizer Post fully_shard and Accelerate should be the same",
    )

    torch.testing.assert_close(
        results["Accelerate"],
        results["Optimizer Pre fully_shard (w/ fix)"],
        msg="Accelerate and Optimizer Pre fully_shard (w/ fix) should be the same",
    )

    torch.distributed.destroy_process_group()


if __name__ == "__main__":
    main()
```
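`main()` builds each benchmark run by currying `evaluate` with `functools.partial`: the per-run arguments (`init_fn`, `run_name`) are bound up front, and the shared `(args, CONFIG)` pair is supplied once at call time. A toy illustration of that pattern (the names below are placeholders, not the benchmark's real functions):

```python
import functools


def evaluate(args, config, init_fn, run_name):
    # The real evaluate trains a model; here we just report what was bound.
    return f"{run_name} ran with lr={config['learning_rate']} on backend {init_fn(args)}"


# Per-run specifics are bound first...
init = functools.partial(lambda args, backend=None: backend, backend="torch")
run = functools.partial(evaluate, init_fn=init, run_name="demo")

# ...then the shared arguments are supplied once, at call time.
result = run(None, {"learning_rate": 3e-5})
assert result == "demo ran with lr=3e-05 on backend torch"
```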
New file (+128 lines): the `MemoryTracker` class used by the benchmark.

```python
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import json
import os
import threading
import time

import psutil
import torch

from accelerate import PartialState


class MemoryTracker:
    def __init__(
        self,
        device: torch.device,
        output_directory: str,
        run_name: str,
        save_memory_snapshot: bool,
        log_interval: float = 0.01,
    ):
        """Class for tracking gpu and cpu memory usage of the process.

        Args:
            device (`torch.device`):
                CUDA device to monitor.
            output_directory (`str`):
                Directory to save the memory usage data to, will be created if it doesn't exist.
            run_name (`str`):
                Name of the run, will be used to name the output files.
            save_memory_snapshot (`bool`):
                Whether to also save `torch.cuda.memory._dump_snapshot` to the output directory.
            log_interval (`float`, *optional*):
                Interval in seconds between memory measurements. Defaults to 0.01.
        """
        self.log_interval = log_interval
        self.save_memory_snapshot = save_memory_snapshot
        self.output_directory = output_directory
        self.run_name = run_name

        self.timestamps = []
        self.allocated_memory = []
        self.reserved_memory = []
        self.virtual_memory = []

        self.start_time = None
        self.running = False

        self._thread = None
        self._state = PartialState()
        self._process = psutil.Process()
        self._device = device

    def _monitor(self):
        self.start_time = time.time()

        while self.running:
            allocated = torch.cuda.memory_allocated(self._device) / (1024 * 1024)
            reserved = torch.cuda.memory_reserved(self._device) / (1024 * 1024)
            virtual_memory = self._process.memory_info().rss / (1024 * 1024)

            self.allocated_memory.append(allocated)
            self.reserved_memory.append(reserved)
            self.virtual_memory.append(virtual_memory)
            self.timestamps.append(time.time() - self.start_time)

            time.sleep(self.log_interval)

    def start(self):
        gc.collect()
        torch.cuda.empty_cache()

        os.makedirs(self.output_directory, exist_ok=True)

        if self.save_memory_snapshot:
            torch.cuda.memory._record_memory_history()

        self.running = True
        self._thread = threading.Thread(target=self._monitor)
        self._thread.daemon = True
        self._thread.start()

    def stop(self):
        self.running = False
        if self._thread:
            self._thread.join()

        if self.save_memory_snapshot and self._state.is_main_process:
            output_file = os.path.join(self.output_directory, f"{self.run_name}_memory_snapshot.pkl")
            torch.cuda.memory._dump_snapshot(output_file)

        if self._state.is_main_process:
            path = os.path.join(self.output_directory, f"{self.run_name}_memory_usage.json")
            with open(path, "w") as f:
                json.dump(
                    {
                        "timestamps": self.timestamps,
                        "allocated_memory": self.allocated_memory,
                        "reserved_memory": self.reserved_memory,
                        "virtual_memory": self.virtual_memory,
                    },
                    f,
                )

        torch.cuda.memory._record_memory_history(False)
        torch.cuda.empty_cache()

    @property
    def peak_allocated_memory(self):
        return max(self.allocated_memory)

    @property
    def peak_reserved_memory(self):
        return max(self.reserved_memory)
```
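`MemoryTracker`'s core is a common polling-thread pattern: a daemon thread samples a metric on a fixed interval until `stop()` flips the flag and joins it. A stripped-down, torch-free sketch of the same idea (`Sampler` and `read_metric` are illustrative names, not part of the benchmark):

```python
import threading
import time


class Sampler:
    """Minimal stand-in for MemoryTracker's polling loop: a daemon thread
    samples a metric every `interval` seconds until stop() joins it."""

    def __init__(self, read_metric, interval=0.01):
        self.read_metric = read_metric
        self.interval = interval
        self.samples = []
        self._running = False
        self._thread = None

    def start(self):
        self._running = True
        self._thread = threading.Thread(target=self._loop, daemon=True)
        self._thread.start()

    def _loop(self):
        while self._running:
            self.samples.append(self.read_metric())
            time.sleep(self.interval)

    def stop(self):
        self._running = False
        if self._thread:
            self._thread.join()


# Sample a constant "metric" for ~20 ms; at least one sample is recorded
# because the loop appends before its first sleep.
sampler = Sampler(lambda: 1.5, interval=0.001)
sampler.start()
time.sleep(0.02)
sampler.stop()
assert sampler.samples and all(s == 1.5 for s in sampler.samples)
```

Making the thread a daemon, as the tracker does, ensures a crashed benchmark process is not kept alive by the monitor.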