Improve performance of load_state_dict
#37902
Conversation
Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the "Ready for review" button.
Rocketknight1
left a comment
LGTM after review - pinging @ArthurZucker for core maintainer review and because he recently updated this code
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
ArthurZucker
left a comment
Thanks! About how much of an increase in load speed did you get with this? 🤗
In my case the time to load Unsloth's Qwen3-30B-A3B reduced from 15 min to 2 min. But loading a dense model of the same size takes only seconds, and only one core is used when loading, so maybe it's still worth considering how to speed this up further.
@woct0rdho it could definitely be worth optimizing this further. In the past, most models were either dense or MoEs with a small number of experts (e.g. Mixtral-8x7B), and so the slowdown here wasn't very important. However, I expect that models like Qwen3 or Deepseek-V3 will be more common in the future, with a huge number of experts but only a small number of activated experts. These models combine high capacity with fast inference, which is very important for "reasoning" training pipelines that involve RL. On a fast SSD, a 30B model should not take 2 minutes to load, so we should consider parallel weight loading or other speedups! cc @Narsil, is there a recommended way to speed up multiple calls to
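For illustration, a minimal sketch of what parallel weight loading could look like outside of transformers, assuming a list of local shard paths like the `filenames` built in the repro scripts further down this thread; `load_shards_parallel` and the worker count are hypothetical, and whether threads help depends on the storage backend and on `load_file` releasing the GIL during the copy.

```python
# Hedged sketch: load several safetensors shards concurrently and merge them.
from concurrent.futures import ThreadPoolExecutor

from safetensors.torch import load_file


def load_shards_parallel(filenames, device="cpu", max_workers=4):
    """Load each shard in a worker thread and merge the partial state dicts."""
    state_dict = {}
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        for shard in pool.map(lambda path: load_file(path, device=device), filenames):
            state_dict.update(shard)
    return state_dict
```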
Hi @woct0rdho, we're investigating the slow loading issue. Can you give us more details on the disk you're loading from? For example, is it a local NVMe drive or a remote mount?
@woct0rdho Are you running on a mounted network disk by any chance? Mounted network disks do not play well with memory mapping, and every read incurs a quite large overhead (depending on how it's mounted, you can see the sort of latencies you see here). On an AWS NVMe the model loads in ~12s for me (4xL4). I created a small repro script outside of transformers:

```python
from huggingface_hub import hf_hub_download
import torch
from safetensors.torch import load_file
import datetime

torch.zeros((2, 2)).cuda()  # Initialize cuda runtime to skip that in measurements

filenames = []
for i in range(16):
    rfilename = f"model-{i + 1:05d}-of-00016.safetensors"
    filename = hf_hub_download("Qwen/Qwen3-30B-A3B", rfilename)
    filenames.append(filename)

start = datetime.datetime.now()
all_data = {}
for i, filename in enumerate(filenames):
    data = load_file(filename, device="cpu")
    all_data.update(data)
print(f"CPU Load Took {datetime.datetime.now() - start}")

start = datetime.datetime.now()
all_data = {}
for i, filename in enumerate(filenames):
    data = load_file(filename, device=f"cuda:{i % 4}")
    all_data.update(data)
print(f"Took {datetime.datetime.now() - start}")
```

Results (4xL4)

The machinery in transformers follows pretty closely what

Sidenote: HDD can also suffer from slow reads, and that can be fixed by changing
Good point. I was running it on AutoDL (a cloud GPU provider). There is a folder

Update: I modified your repro script a bit to load only 4 files on 1 GPU:

```python
#!/usr/bin/env python3
import os

os.environ["HF_HOME"] = "/root/autodl-tmp/.cache"

import datetime

import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

torch.zeros((2, 2)).cuda()  # Initialize cuda runtime to skip that in measurements

filenames = []
for i in range(4):
    rfilename = f"model-{i + 1:05d}-of-00016.safetensors"
    filename = hf_hub_download("Qwen/Qwen3-30B-A3B", rfilename)
    # rfilename = f"model-{i + 1:05d}-of-00004.safetensors"
    # filename = hf_hub_download("unsloth/Qwen3-30B-A3B-bnb-4bit", rfilename)
    filenames.append(filename)

start = datetime.datetime.now()
all_data = {}
for i, filename in enumerate(filenames):
    data = load_file(filename, device="cpu")
    all_data.update(data)
print(f"CPU Load Took {datetime.datetime.now() - start}")

start = datetime.datetime.now()
all_data = {}
for i, filename in enumerate(filenames):
    data = load_file(filename, device="cuda:0")
    all_data.update(data)
print(f"GPU Load Took {datetime.datetime.now() - start}")
```

The results are:
Improve performance of load_state_dict
What does this PR do?
We avoid executing `get_slice` unless `map_location == "meta"`, to improve the performance when loading a model with a large number of tensors.

Even though we avoid the dtype check in Python, the dtype will be checked at https://github.com/huggingface/safetensors/blob/7d5af853631628137a79341ddc5611d18a17f3fe/bindings/python/src/lib.rs#L1186
Fixes #37887
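To make the change concrete, here is a simplified sketch of the idea (not the exact transformers code): only the `map_location == "meta"` path still walks per-tensor slices to recover shapes and dtypes without reading data, while the regular path goes straight to `get_tensor`. The `load_checkpoint` name and the partial `_STR_TO_DTYPE` mapping are illustrative, and the slice `get_shape()`/`get_dtype()` accessors are assumed from the safetensors Python bindings.

```python
# Hedged sketch of the code path this PR optimizes, not the real implementation.
import torch
from safetensors import safe_open

# Illustrative (partial) mapping from safetensors dtype strings to torch dtypes.
_STR_TO_DTYPE = {"F32": torch.float32, "F16": torch.float16, "BF16": torch.bfloat16}


def load_checkpoint(checkpoint_file, map_location="cpu"):
    device = "cpu" if map_location == "meta" else map_location
    state_dict = {}
    with safe_open(checkpoint_file, framework="pt", device=device) as f:
        for k in f.keys():
            if map_location == "meta":
                # get_slice reads header information only, no tensor data.
                tensor_slice = f.get_slice(k)
                state_dict[k] = torch.empty(
                    tensor_slice.get_shape(),
                    dtype=_STR_TO_DTYPE.get(tensor_slice.get_dtype(), torch.float32),
                    device="meta",
                )
            else:
                # Fast path: no per-tensor slice/dtype handling in Python; the
                # dtype is still validated inside the safetensors Rust code
                # linked above.
                state_dict[k] = f.get_tensor(k)
    return state_dict
```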
Before submitting

- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@Rocketknight1