TP initialization module-by-module #35996
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
ArthurZucker
left a comment
Nice nice
for submodule in model.modules():
    full_tp_plan.update(getattr(submodule, "_tp_plan", {}))
IMO we should only do this for PreTrainedModels, no?
I figured that maybe it would be a bit more future-proof to iterate over all modules (it's not costly) -- but can be changed for sure!
Let's only do PreTrainedModel for now!
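As a hedged sketch of that suggestion (assuming the collection loop lives in the loading path; collect_tp_plan is an illustrative helper name), restricting the scan to PreTrainedModel submodules would look like:

from transformers import PreTrainedModel

def collect_tp_plan(model):
    # Only gather tensor-parallel plans from PreTrainedModel submodules,
    # rather than from every nn.Module in the model.
    full_tp_plan = {}
    for submodule in model.modules():
        if isinstance(submodule, PreTrainedModel):
            full_tp_plan.update(getattr(submodule, "_tp_plan", None) or {})
    return full_tp_plan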
for param, plan in full_tp_plan.items():
    # "*" is a placeholder for layer indices, so we replace it with "[0-9]+" in the regex pattern
    pattern = param.replace("*", "[0-9]+")
    if re.search(pattern, parent_module_name):
        current_module_plan = plan
        break
I don't think we need to iterate over the full tp_plan; we should re-create the key instead.
The tp_plan does not contain the full module names (usually it starts with "layers"), so to be general it's much easier to iterate over the keys instead of starting from the module name and trying to derive the corresponding tp_plan key (because the prefixes of the tp_plan keys may change). Once again, it's not costly at all since the tp_plan is very small.
Mmm, agreed cost-wise it's a tad of a waste! But no worries.
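For reference, here is a self-contained sketch of the lookup discussed above (find_module_plan is an illustrative name; in the PR the matching happens inline while loading):

import re

def find_module_plan(full_tp_plan, parent_module_name):
    # tp_plan keys use "*" as a placeholder for layer indices, e.g.
    # "layers.*.self_attn.q_proj" should match "model.layers.3.self_attn.q_proj".
    for param, plan in full_tp_plan.items():
        pattern = param.replace("*", "[0-9]+")
        if re.search(pattern, parent_module_name):
            return plan
    return None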
process_device = list(device_map.values())[0]
all_module_parameters_initialized = all(
    m.device == process_device for m in parent_module.parameters(recurse=False)
) and all(m.device == process_device for m in parent_module.buffers(recurse=False))
Similarly, this might be a tad costly for MoE models, for example, and is not necessarily needed.
We can either:
- maybe load for the previous layer? (so layer 1 loads layer 0; this way it always runs after all biases are loaded)
- check is_hf_initialized, as I think it should hold info about everything being initialized
TL;DR: let's avoid loops.
Unfortunately, the shards that are loaded are not necessarily in order, so we cannot rely on that in general... And we check it only for leaves in the state dict (i.e. the Linear/Embedding/Norm layers), which have at most 2 or 3 parameters(), so it's not much of an overhead I think. It does not look like we can use is_hf_initialized here (from what I understand, it checks that the weights were created, not that the correct state_dict was loaded and then dispatched to the correct device).
In any case, if we did not specify tp_plan="auto", all of this is completely skipped.
It normally checks that the weights were properly loaded! Because otherwise it goes through the init loop.
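A self-contained sketch of the per-leaf check discussed in this thread (leaf_fully_on_device is an illustrative name; the PR inlines this logic), assuming parent_module is a leaf such as a Linear/Embedding/Norm layer:

import torch
from torch import nn

def leaf_fully_on_device(parent_module: nn.Module, process_device: torch.device) -> bool:
    # Leaf modules hold at most a couple of parameters (weight, bias) and maybe a buffer,
    # so these non-recursive loops stay cheap even for MoE models.
    params_on_device = all(p.device == process_device for p in parent_module.parameters(recurse=False))
    buffers_on_device = all(b.device == process_device for b in parent_module.buffers(recurse=False))
    return params_on_device and buffers_on_device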
if buffer.device != tp_device:
    buffer.data = buffer.to(tp_device)
Interesting; remember that now we pass the cos and sin as inputs to all layers, so those are passed along.
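For context, a small sketch that wraps the snippet above into a helper (move_buffers_to_tp_device is an illustrative name), moving any buffers that were not sharded onto the tensor-parallel device:

import torch
from torch import nn

def move_buffers_to_tp_device(module: nn.Module, tp_device: torch.device) -> None:
    # Buffers are not covered by the TP plan, so just make sure they end up on the right device.
    for buffer in module.buffers(recurse=False):
        if buffer.device != tp_device:
            buffer.data = buffer.to(tp_device)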
# The expected full model memory footprint
expected_model_memory = 16
# The expected model memory footprint. We add 1 as not all the modules are split (e.g. the embeddings)
let's add something related to this in the test
Yes, it currently checks that we do not use more than the expected memory divided by the world size, i.e., no more than 5 GiB per GPU in my tests on a DGX for Llama 8B (expected memory per device = a bit more than 4 GiB).
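A sketch of the kind of assertion described here, assuming per-device usage is read with torch.cuda.max_memory_allocated and that expected_model_memory is expressed in GiB (names are illustrative, not the exact test code):

import torch

def check_per_device_memory(expected_model_memory_gib: float, world_size: int) -> None:
    # Each rank should hold roughly 1/world_size of the model (plus a small margin for
    # non-split modules such as the embeddings), never the full model.
    used_gib = torch.cuda.max_memory_allocated() / 1024**3
    per_device_budget = expected_model_memory_gib / world_size
    assert used_gib <= per_device_budget, (
        f"Rank uses {used_gib:.2f} GiB, expected at most {per_device_budget:.2f} GiB"
    )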
ArthurZucker
left a comment
Temporary solution IMO, but much needed. Thanks, let's merge!
What does this PR do?
As per the title! At loading time, the parallelization is now applied module-by-module, so that there is no memory overhead compared to the final weight distribution!
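To illustrate the idea (a conceptual sketch, not the PR's actual implementation): as soon as a leaf module's tensor is read from the checkpoint, only the current rank's slice is kept, so the full unsharded weights never need to sit on the device at once. shard_for_rank and the dim-0 split are illustrative assumptions:

import torch

def shard_for_rank(full_weight: torch.Tensor, rank: int, world_size: int, dim: int = 0) -> torch.Tensor:
    # Keep only this rank's shard right after the tensor is loaded, so peak memory per
    # device stays close to the final sharded footprint instead of the full model size.
    chunk = full_weight.size(dim) // world_size
    return full_weight.narrow(dim, rank * chunk, chunk).contiguous()

# Example: rank 1 of 2 keeps rows 4..7 of an 8x4 weight.
weight = torch.arange(32, dtype=torch.float32).reshape(8, 4)
print(shard_for_rank(weight, rank=1, world_size=2).shape)  # torch.Size([4, 4])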