Merged
Changes from 85 commits
Commits
89 commits
2cfe2db
Feat: initial conversion tool draft
S1ro1 Feb 11, 2025
d055eab
Feat: add value mapping to conversion tool
S1ro1 Feb 11, 2025
23ea666
Refactor: move from os to pathlib
S1ro1 Feb 11, 2025
7deaaaa
Feat: add first tests
S1ro1 Feb 11, 2025
d506b95
Feat: more tests
S1ro1 Feb 11, 2025
3eab226
Feat: minor fixes + dataclass conversions
S1ro1 Feb 11, 2025
269c217
Feat: more remapping
S1ro1 Feb 12, 2025
ba5372c
Fix: namespace has no attribute version + style
S1ro1 Feb 12, 2025
b16131b
Fix: offload params behavior
S1ro1 Feb 12, 2025
b0cc66b
Feat: add option to only rename keys in the config file to
S1ro1 Feb 12, 2025
8bc5cb3
Fix: wrong attr name
S1ro1 Feb 12, 2025
c31cc55
Merge branch 'main' into dev/fsdp2
S1ro1 Feb 12, 2025
bc70ec2
Fix: partially resolve comments
S1ro1 Feb 13, 2025
00dafc4
Feat: work on config command + minor fixes to reflect changes
S1ro1 Feb 13, 2025
7a92dac
Refactor: style + quality
S1ro1 Feb 13, 2025
b724f9c
Feat: fsdp2 initial work
S1ro1 Feb 13, 2025
7cd2587
Feat: some cleanups and first running fsdp2
S1ro1 Feb 13, 2025
d920e94
Fix: version checks + mixed precision policy
S1ro1 Feb 17, 2025
3cc7c20
Refactor: style + quality
S1ro1 Feb 17, 2025
432b4ff
Remove obsolete todos
S1ro1 Feb 17, 2025
bb83985
Feat: grad norm clipping
S1ro1 Feb 17, 2025
7759d79
Fix: tests + rename attrs
S1ro1 Feb 17, 2025
f8bfa04
Refactor: style + quality
S1ro1 Feb 17, 2025
47a12b9
Fix: None object is not iterable
S1ro1 Feb 17, 2025
030fa5d
Fix: default cpu_offload for fsdp2
S1ro1 Feb 18, 2025
c5526aa
Fix: cpu offload now behaves correctly
S1ro1 Feb 26, 2025
4984c8a
Feat: apply_activation_checkpointing
S1ro1 Feb 26, 2025
914e55f
Fix: append to models
S1ro1 Mar 3, 2025
1304284
Feat: start on concept guide
S1ro1 Mar 6, 2025
f9061f6
Merge branch 'main' into dev/fsdp2
S1ro1 Mar 6, 2025
c596c12
wip: concept guide
S1ro1 Mar 6, 2025
d55bab7
Fix: toctree
S1ro1 Mar 6, 2025
8494f7b
cleanup of the concept guide
S1ro1 Mar 7, 2025
f298b5e
Fix: minor fixes + mp
S1ro1 Mar 10, 2025
4e4b664
Fix: quality + | to union
S1ro1 Mar 10, 2025
30961c9
Feat: backwards compatibility + args cleanup
S1ro1 Mar 11, 2025
314df77
Fix: style + quality
S1ro1 Mar 11, 2025
5238cb2
Feat: enable dropping refs when getting named params
S1ro1 Mar 14, 2025
e104c10
Merge branch 'main' into dev/fsdp2
S1ro1 Mar 14, 2025
8347b91
Fix: memory footprint with fsdp2
S1ro1 Mar 14, 2025
70f4882
Feat: cpu ram efficient loading
S1ro1 Mar 14, 2025
c597598
Fix: mp
S1ro1 Mar 17, 2025
fa73295
Fix: not warn about sync_modules if fsdp version is 1
S1ro1 Mar 17, 2025
ee1267a
Refactor: minor changes
S1ro1 Mar 17, 2025
021e153
Small fixes + refactors
S1ro1 Mar 18, 2025
04a5ffb
Feat: docs + cleanup
S1ro1 Mar 18, 2025
d9fba39
Feat: saving works (not sure about optim)
S1ro1 Mar 18, 2025
0696c96
More loading/saving work
S1ro1 Mar 19, 2025
0d0abde
Feat: disable local_state_dict for fsdp2
S1ro1 Mar 19, 2025
1fe97f6
Fix: fsdp2 convergence
S1ro1 Mar 21, 2025
ab4e58d
Merge branch 'main' into dev/fsdp2
S1ro1 Mar 21, 2025
d6edd8c
Feat: working comparison script
S1ro1 Mar 21, 2025
8a48038
Feat: memory tracking fsdp2
S1ro1 Mar 21, 2025
a563493
Feat: memory visualizer
S1ro1 Mar 21, 2025
a8d127b
Feat: more work on benchmark
S1ro1 Mar 21, 2025
5560555
Fix: raise error if model+optimizer arent prepared together
S1ro1 Mar 21, 2025
ab5a6cf
Minor fixes
S1ro1 Mar 21, 2025
31ce351
Style
S1ro1 Mar 21, 2025
b619589
More warnings
S1ro1 Mar 21, 2025
4dab060
Fix: reshard_after_forward vs sharding_strategy conflict
S1ro1 Mar 21, 2025
c7ef53d
Refactor: clean up accelerator
S1ro1 Mar 22, 2025
8fadfbf
Feat: more testing in fsdp2 benchmark
S1ro1 Mar 22, 2025
79d5cf7
Fix: memory visualizer
S1ro1 Mar 22, 2025
d633c12
Untested: support load/save_state
S1ro1 Mar 22, 2025
8a4de4d
Feat: concept guide improvements
S1ro1 Mar 22, 2025
4881bbf
Refactor: concept guide
S1ro1 Mar 24, 2025
9fd8fd5
Feat: benchmark works
S1ro1 Mar 24, 2025
0d53cad
Feat: more work on fsdp2 benchmark
S1ro1 Mar 25, 2025
9251dcd
Fix: note syntax
S1ro1 Mar 25, 2025
25f15c7
Fix: small fixes + make original tests work
S1ro1 Mar 25, 2025
4dfb12c
Fix: grad scaling
S1ro1 Mar 25, 2025
d0244e2
Feat: reshard after forward tests
S1ro1 Mar 25, 2025
d2891a2
Feat: backward prefetch tests
S1ro1 Mar 25, 2025
f99e980
Feat: tests for fsdp2
S1ro1 Mar 26, 2025
5619fea
Refactor: minor fixes
S1ro1 Mar 26, 2025
3fba7e2
Feat: fsdp_utils docstrings
S1ro1 Mar 26, 2025
6df6483
Feat: autodoc fsdp.md
S1ro1 Mar 26, 2025
d24aa9d
Docs: get_module_children_bottom_up
S1ro1 Mar 26, 2025
050488d
Fix: remove unused images
S1ro1 Mar 26, 2025
e984571
Refactor: benchmark cleanup
S1ro1 Mar 26, 2025
ccf3d42
Fix: docs
S1ro1 Mar 26, 2025
124a56f
Feat: final doc changes
S1ro1 Mar 26, 2025
5bb4f30
Fix: torch.distributed has no attribute tensor
S1ro1 Mar 26, 2025
0f71c8c
Fix: style
S1ro1 Mar 26, 2025
0a2791b
Feat: tests include version in failures
S1ro1 Mar 26, 2025
9f37abd
Fix: benchmark force model to load in fp32
S1ro1 Mar 26, 2025
a7333d1
Fix: rename runs
S1ro1 Mar 27, 2025
4d4f829
Feat: last minor fixes
S1ro1 Mar 27, 2025
7ad58a1
Feat: new benchmark images
S1ro1 Mar 27, 2025
74 changes: 74 additions & 0 deletions benchmarks/fsdp2/README.md
@@ -0,0 +1,74 @@
# FSDP2 Benchmarks

This benchmark showcases `FSDP2` in 🤗 `accelerate` and compares it to a raw `torch` baseline.

## Overview

This benchmark consists of two parts:
- `main.py` is the main script that runs the benchmark
- `visualize.py` is the script that visualizes the results

## Motivation

We want to showcase that 🤗 `accelerate`'s integration of `FSDP2` is on par with raw PyTorch, and highlight a "broken" behavior in PyTorch: creating an optimizer before applying `FSDP2` **doesn't result in a working training loop** (more on this later).
This script showcases **matching memory usage and convergence between `accelerate` and `torch`'s baseline.**
To deal with this breaking change (and maintain an API backward compatible with FSDP1), `accelerate` had to come up with a workaround, since `accelerate` assumes that the user will nearly always create a model, optimizer, scheduler, etc. beforehand and bring them themselves. Creating the optimizer beforehand led to a stark increase in memory, and the model would not even train.
To work around this, we replace the parameters inside the optimizer with the newly created FSDP2-sharded ones. More about this can be found in this [blog post (TBD)](TODO)
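The idea behind the fix can be sketched in plain Python. This is an illustrative sketch, not `accelerate`'s actual code, and `swap_optimizer_params` is a hypothetical helper: after sharding, the model holds *new* parameter objects, so every optimizer param group must be repointed at them.

```python
# Illustrative sketch (hypothetical helper, not accelerate's real API):
# after fully_shard, the model's parameters are new objects, so the
# optimizer's param groups still reference the stale, pre-shard tensors.

def swap_optimizer_params(old_named_params, new_named_params, param_groups):
    """Replace stale parameter references in `param_groups` in place.

    old_named_params / new_named_params: dict of name -> parameter object,
    captured before and after sharding. param_groups: optimizer-style list
    of dicts, each with a "params" list.
    """
    # Recover each pre-shard parameter's name by object identity...
    old_id_to_name = {id(p): name for name, p in old_named_params.items()}
    # ...then substitute the post-shard parameter with the same name.
    for group in param_groups:
        group["params"] = [
            new_named_params[old_id_to_name[id(p)]] for p in group["params"]
        ]

# Toy usage with plain objects standing in for tensors:
old = {"w": object(), "b": object()}
new = {"w": object(), "b": object()}  # the "sharded" replacements
groups = [{"params": [old["w"], old["b"]], "lr": 3e-5}]
swap_optimizer_params(old, new, groups)
assert groups[0]["params"] == [new["w"], new["b"]]
```

The key design point is that hyperparameters like `lr` stay untouched; only the parameter references are rewritten.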
> [!WARNING]
> This script is intended to fit on 2x 24GB GPUs, though with so few GPUs it's only possible to see the difference in convergence, not the memory difference (discrepancies in grad allocation result in lower memory usage in the non-fixed case). Attached below are results from 8x H100 GPUs, where the difference is visible.
> TLDR: more GPUs = bigger memory difference between the fixed and non-fixed cases.

## Results

Here are the results from running the benchmark on 8x H100 GPUs:

<p align="center">
<img src="imgs/allocated_memory.png" width="80%" alt="Allocated Memory Usage">
</p>
<p align="center">
<img src="imgs/reserved_memory.png" width="80%" alt="Reserved Memory Usage">
</p>

As you can see, the memory usage of `accelerate` and `torch_post_shard` (the **intended** way) are very similar, while `torch_pre_shard_not_fixed` uses significantly more memory. Our fix in `torch_pre_shard_fixed` brings the memory usage back in line with the **intended** approach.

> [!WARNING]
> Timing discrepancies are due to all benchmarks being run in a single script.


## Running

To run the benchmark, you can either use `accelerate launch` or `torchrun` (with the appropriate params):
```bash
accelerate launch main.py
```
```bash
# For two GPUs
torchrun --nproc_per_node 2 main.py
```

The script supports multiple configurable options; you can learn about them by running:
```bash
python3 main.py --help
```

This script will run 4 different benchmarks:
- `torch_post_shard`: `torch` baseline where the optimizer is created after applying `FSDP2`; this is the **intended** way to do it
- `torch_pre_shard_not_fixed`: `torch` baseline where the optimizer is created before applying `FSDP2`
- `torch_pre_shard_fixed`: `torch` baseline where the optimizer is created before applying `FSDP2`, but with our fix applied to the optimizer
- `accelerate`: `accelerate`'s own integration of `FSDP2`, where the optimizer is created before applying `FSDP2` and our fix is applied to it

Memory results are saved in a folder specified by the `--output_dir` argument.
Optionally, you can pass `--save_memory_snapshot` to also save a torch memory snapshot, which can then be viewed with the [`torch memory viz`](https://pytorch.org/memory_viz) tool.

## Visualizing results

To visualize the results, you can run:

```bash
python3 visualize.py --dir <path_to_output_dir>
```

This will create two plots showing allocated and reserved memory usage across all the benchmarks discussed above.
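If you want to post-process a run without `visualize.py`, the JSON files written by the tracker in `measure_utils.py` use the keys `timestamps`, `allocated_memory`, `reserved_memory`, and `virtual_memory`. A minimal consumer could look like this (the file name below is a made-up example):

```python
import json

def load_peaks(path):
    """Return peak allocated/reserved memory (in MB) from a *_memory_usage.json file."""
    with open(path) as f:
        data = json.load(f)
    return max(data["allocated_memory"]), max(data["reserved_memory"])

# Toy example with a hand-written file mirroring the tracker's layout:
with open("demo_memory_usage.json", "w") as f:
    json.dump({"timestamps": [0.0, 0.01],
               "allocated_memory": [100.0, 250.0],
               "reserved_memory": [120.0, 300.0],
               "virtual_memory": [50.0, 50.0]}, f)

print(load_peaks("demo_memory_usage.json"))  # → (250.0, 300.0)
```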



Binary file added benchmarks/fsdp2/imgs/allocated_memory.png
Binary file added benchmarks/fsdp2/imgs/reserved_memory.png
121 changes: 121 additions & 0 deletions benchmarks/fsdp2/main.py
@@ -0,0 +1,121 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
from typing import Callable

import torch

from accelerate import Accelerator
from utils import parse_args, prepare_accelerate, prepare_torch


MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
LEARNING_RATE = 3e-5

CONFIG = {
"model_name": MODEL_NAME,
"learning_rate": LEARNING_RATE,
}


def train(
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
train_dataloader: torch.utils.data.DataLoader,
accelerator: Accelerator,
) -> torch.Tensor:
losses = []
for batch in train_dataloader:
optimizer.zero_grad()
outputs = model(**batch, use_cache=False)

loss = outputs.loss
losses.append(loss.item())
accelerator.backward(loss)
optimizer.step()

return torch.tensor(losses)


def evaluate(args, config: dict, init_fn: Callable, run_name: str) -> torch.Tensor:
model, optimizer, dataloader, accelerator, memory_tracker = init_fn(args, config)

loss = train(model, optimizer, dataloader, accelerator)

memory_tracker.stop()
msg = f"""Results for {run_name} (rank 0):
Loss: {loss[-1].item()}
Peak Allocated Memory: {float(memory_tracker.peak_allocated_memory):.2f} MB
Peak Reserved Memory: {float(memory_tracker.peak_reserved_memory):.2f} MB
{'-' * 34}"""
accelerator.print(msg)
return loss


def main():
args = parse_args()
evaluations = [
functools.partial(
evaluate,
init_fn=functools.partial(prepare_torch, post_shard_optimizer=False, apply_optimizer_fix=True),
run_name="Optimizer Pre fully_shard (w/ fix)",
),
functools.partial(
evaluate,
init_fn=functools.partial(prepare_torch, post_shard_optimizer=False, apply_optimizer_fix=False),
run_name="Optimizer Pre fully_shard (w/o fix)",
),
functools.partial(
evaluate,
init_fn=functools.partial(prepare_torch, post_shard_optimizer=True),
run_name="Optimizer Post fully_shard",
),
functools.partial(evaluate, init_fn=prepare_accelerate, run_name="Accelerate"),
]
labels = [
"Optimizer Pre fully_shard (w/ fix)",
"Optimizer Pre fully_shard (w/o fix)",
"Optimizer Post fully_shard",
"Accelerate",
]

results = {}

for evaluation, label in zip(evaluations, labels):
results[label] = evaluation(args, CONFIG)

torch.testing.assert_close(
results["Optimizer Post fully_shard"],
results["Optimizer Pre fully_shard (w/ fix)"],
msg="Optimizer Post fully_shard and Optimizer Pre fully_shard (w/ fix) should be the same",
)

torch.testing.assert_close(
results["Optimizer Post fully_shard"],
results["Accelerate"],
msg="Optimizer Post fully_shard and Accelerate should be the same",
)

torch.testing.assert_close(
results["Accelerate"],
results["Optimizer Pre fully_shard (w/ fix)"],
msg="Accelerate and Optimizer Pre fully_shard (w/ fix) should be the same",
)

torch.distributed.destroy_process_group()


if __name__ == "__main__":
main()
128 changes: 128 additions & 0 deletions benchmarks/fsdp2/measure_utils.py
@@ -0,0 +1,128 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import json
import os
import threading
import time

import psutil
import torch

from accelerate import PartialState


class MemoryTracker:
def __init__(
self,
device: torch.device,
output_directory: str,
run_name: str,
save_memory_snapshot: bool,
log_interval: float = 0.01,
):
"""Class for tracking GPU and CPU memory usage of the process.

Args:
device (`torch.device`):
CUDA device to monitor.
output_directory (`str`):
Directory to save the memory usage data to, will be created if it doesn't exist.
run_name (`str`):
Name of the run, will be used to name the output files.
save_memory_snapshot (`bool`):
Whether to also save `torch.cuda.memory._dump_snapshot` to the output directory.
log_interval (`float`, *optional*):
Interval in seconds between memory measurements. Defaults to 0.01.
"""
self.log_interval = log_interval
self.save_memory_snapshot = save_memory_snapshot
self.output_directory = output_directory
self.run_name = run_name

self.timestamps = []
self.allocated_memory = []
self.reserved_memory = []
self.virtual_memory = []

self.start_time = None
self.running = False

self._thread = None
self._state = PartialState()
self._process = psutil.Process()
self._device = device

def _monitor(self):
self.start_time = time.time()

while self.running:
allocated = torch.cuda.memory_allocated(self._device) / (1024 * 1024)
reserved = torch.cuda.memory_reserved(self._device) / (1024 * 1024)
virtual_memory = self._process.memory_info().rss / (1024 * 1024)

self.allocated_memory.append(allocated)
self.reserved_memory.append(reserved)
self.virtual_memory.append(virtual_memory)
self.timestamps.append(time.time() - self.start_time)

time.sleep(self.log_interval)

def start(self):
gc.collect()
torch.cuda.empty_cache()

os.makedirs(self.output_directory, exist_ok=True)

if self.save_memory_snapshot:
torch.cuda.memory._record_memory_history()

self.running = True
self._thread = threading.Thread(target=self._monitor)
self._thread.daemon = True
self._thread.start()

def stop(self):
self.running = False
if self._thread:
self._thread.join()

if self.save_memory_snapshot and self._state.is_main_process:
output_file = os.path.join(self.output_directory, f"{self.run_name}_memory_snapshot.pkl")
torch.cuda.memory._dump_snapshot(output_file)

if self._state.is_main_process:
path = os.path.join(self.output_directory, f"{self.run_name}_memory_usage.json")
with open(path, "w") as f:
json.dump(
{
"timestamps": self.timestamps,
"allocated_memory": self.allocated_memory,
"reserved_memory": self.reserved_memory,
"virtual_memory": self.virtual_memory,
},
f,
)

torch.cuda.memory._record_memory_history(False)
torch.cuda.empty_cache()

@property
def peak_allocated_memory(self):
return max(self.allocated_memory)

@property
def peak_reserved_memory(self):
return max(self.reserved_memory)
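The core of `MemoryTracker` is a plain polling thread; stripped of the CUDA and `psutil` calls, the pattern reduces to the following sketch (the `Sampler` class and `read_fn` callback are illustrative names, not part of the benchmark):

```python
import threading
import time

class Sampler:
    """Minimal version of the MemoryTracker polling loop: a daemon thread
    that records a reading every `interval` seconds until stopped."""

    def __init__(self, read_fn, interval=0.01):
        self.read_fn = read_fn  # e.g. a memory-usage getter
        self.interval = interval
        self.samples = []
        self.running = False
        self._thread = None

    def _monitor(self):
        while self.running:
            self.samples.append(self.read_fn())
            time.sleep(self.interval)

    def start(self):
        self.running = True
        self._thread = threading.Thread(target=self._monitor, daemon=True)
        self._thread.start()

    def stop(self):
        # Flipping the flag ends the loop; join() waits for the last sample.
        self.running = False
        if self._thread:
            self._thread.join()

# Usage: sample a monotonically increasing counter for ~50 ms.
counter = iter(range(10**6))
s = Sampler(lambda: next(counter), interval=0.005)
s.start()
time.sleep(0.05)
s.stop()
assert len(s.samples) >= 1
```

The daemon flag mirrors the real tracker: if the main process exits early, the sampling thread is not allowed to keep it alive.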