Conversation

@Cyrilvallez (Member) commented on Jan 31, 2025

What does this PR do?

As per the title! At loading time, the parallelization is now applied module-by-module, so no memory overhead is required beyond what the final weight distribution will already use.
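
For context, a minimal, illustrative sketch of applying a TP plan module-by-module while loading. Everything here (the load_with_module_level_tp helper, the colwise/rowwise handling, reading from an in-memory state_dict) is an assumption for the sketch, not the actual transformers implementation:

import re
import torch
import torch.nn as nn
import torch.distributed as dist

def load_with_module_level_tp(model, state_dict, tp_plan, device):
    # Illustrative only: shard each module's weights as the module is processed,
    # so a rank only keeps its own slice of every parameter on `device`.
    rank, world_size = dist.get_rank(), dist.get_world_size()
    for name, module in model.named_modules():
        # Find the TP style for this module, if any ("*" stands for a layer index)
        style = next(
            (plan for key, plan in tp_plan.items() if re.search(key.replace("*", "[0-9]+"), name)),
            None,
        )
        for param_name, _ in list(module.named_parameters(recurse=False)):
            key = f"{name}.{param_name}" if name else param_name
            full = state_dict[key]
            if style == "colwise":                      # keep only this rank's slice along dim 0
                shard = full.chunk(world_size, dim=0)[rank]
            elif style == "rowwise" and full.dim() > 1:  # keep only this rank's slice along the last dim
                shard = full.chunk(world_size, dim=-1)[rank]
            else:                                       # no plan entry: keep the tensor replicated
                shard = full
            setattr(module, param_name, nn.Parameter(shard.to(device)))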

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker (Collaborator) left a comment

Nice nice

Comment on lines +798 to +799
for submodule in model.modules():
    full_tp_plan.update(getattr(submodule, "_tp_plan", {}))

ArthurZucker (Collaborator)

IMO we should only do this for PreTrainedModels no?

Cyrilvallez (Member Author)

I figured it would be a bit more future-proof to iterate over all modules (it's not costly) -- but it can be changed for sure!

ArthurZucker (Collaborator)

Let's only do PreTrained for now!
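
A minimal sketch of that restriction, assuming the `model` and `full_tp_plan` names from the diff above (PreTrainedModel is the standard transformers base class):

from transformers import PreTrainedModel

full_tp_plan = {}
for submodule in model.modules():
    # Only collect sub-plans from PreTrainedModel instances, as suggested above
    if isinstance(submodule, PreTrainedModel):
        full_tp_plan.update(getattr(submodule, "_tp_plan", {}))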

Comment on lines +915 to +921
for param, plan in full_tp_plan.items():
    # "*" is a placeholder for layer indices, so we replace it with "[0-9]+" in the regex pattern
    pattern = param.replace("*", "[0-9]+")
    if re.search(pattern, parent_module_name):
        current_module_plan = plan
        break

ArthurZucker (Collaborator)

I don't think we need to iterate over the full tp_plan, but we should be re-creating the key instead

@Cyrilvallez (Member Author) commented on Feb 3, 2025

The tp_plan does not contain the full module names (usually they start with "layers"), so to stay general it's much easier to iterate over the keys than to start from the module name and try to reconstruct the tp_plan key (because the prefixes of the tp_plan keys may change). Once again, it's not costly at all since the tp_plan is very small.
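
For illustration, the matching described above works roughly like this; the plan entries below are just examples in the style of a base_model_tp_plan, not an exhaustive plan:

import re

full_tp_plan = {
    "layers.*.self_attn.q_proj": "colwise",
    "layers.*.mlp.down_proj": "rowwise",
}
parent_module_name = "model.layers.12.self_attn.q_proj"

current_module_plan = None
for param, plan in full_tp_plan.items():
    # "*" stands for a layer index, so it becomes "[0-9]+" in the regex
    pattern = param.replace("*", "[0-9]+")
    if re.search(pattern, parent_module_name):
        current_module_plan = plan
        break

print(current_module_plan)  # -> "colwise"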

ArthurZucker (Collaborator)

mmm agreed cost-wise, it's a tad of a waste! but no worries

Comment on lines +924 to +927
process_device = list(device_map.values())[0]
all_module_parameters_initialized = all(
    m.device == process_device for m in parent_module.parameters(recurse=False)
) and all(m.device == process_device for m in parent_module.buffers(recurse=False))

ArthurZucker (Collaborator)

Similarly, this might be a tad costly for MoE, for example, and is not necessarily needed.
We can either:

  • maybe load for the previous layer? (so layer 1 loads layer 0; this way it's always after all biases are loaded)
  • check is_hf_initialized, as I think it should hold info about everything being initialized

TL;DR: let's avoid loops.

Cyrilvallez (Member Author)

Unfortunately, the shards that are loaded are not necessarily in order, so we cannot rely on that in general... And we check it only for the leaves in the state dict (i.e. the Linear/Embedding/Norm layers), so they have at most 2 or 3 parameters(); not much of an overhead, I think. It does not look like we can use is_hf_initialized here (from what I understand, it checks that the weights were created, not that the correct state_dict was loaded and then dispatched to the correct device).
In any case, if we did not specify tp_plan="auto", all of this is completely skipped.
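
As a quick standalone illustration of the "2 or 3 parameters" point (a toy check, not part of the PR), the leaf modules in question only expose a handful of direct tensors:

import torch.nn as nn

# The device check above only loops over a module's direct parameters/buffers,
# and the leaves in the state dict hold very few of them.
for leaf in (nn.Linear(8, 8), nn.Embedding(10, 8), nn.LayerNorm(8)):
    n_params = len(list(leaf.parameters(recurse=False)))
    n_buffers = len(list(leaf.buffers(recurse=False)))
    print(type(leaf).__name__, n_params, n_buffers)  # Linear 2 0, Embedding 1 0, LayerNorm 2 0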

ArthurZucker (Collaborator)

Normally it checks that weights were properly loaded! Because otherwise it goes through the init loop.

Comment on lines +4402 to +4403
if buffer.device != tp_device:
    buffer.data = buffer.to(tp_device)

ArthurZucker (Collaborator)

Interesting. Remember that now we pass the cos and sin as inputs to all layers, so those are passed too.
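
For context, the kind of buffer this concerns is something like a rotary embedding's inv_freq, which is registered as a buffer rather than a parameter. A toy standalone version of the move above (ToyRotary and tp_device are invented stand-ins for this sketch):

import torch
import torch.nn as nn

class ToyRotary(nn.Module):
    # Stand-in for a rotary-embedding-style module: inv_freq is a buffer, not a
    # parameter, so only the buffer loop above will ever move it.
    def __init__(self, dim=8):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

tp_device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
module = ToyRotary()
for buffer in module.buffers():
    if buffer.device != tp_device:
        buffer.data = buffer.to(tp_device)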

# The expected full model memory footprint
expected_model_memory = 16
# The expected model memory footprint. We add 1 as not all the modules are split (e.g. the embeddings)

ArthurZucker (Collaborator)

let's add something related to this in the test

Cyrilvallez (Member Author)

Yes, it currently checks that we do not use more than the expected memory divided by the world size, i.e. no more than 5 GiB per GPU in my tests on a DGX for Llama 8B (expected memory per device = a bit more than 4 GiB).
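
A rough sketch of that kind of assertion (illustrative only; the numbers, the margin, and the way memory is measured are assumptions for the sketch, not the actual test in the repo):

import torch
import torch.distributed as dist

expected_model_memory = 16  # GiB, full (unsharded) model footprint
margin = 1                  # GiB, for modules that are not split (e.g. the embeddings)

world_size = dist.get_world_size()
max_allocated_gib = torch.cuda.max_memory_allocated() / 1024**3
# Each rank should stay within its share of the full footprint, plus the margin
assert max_allocated_gib <= expected_model_memory / world_size + margin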

@ArthurZucker (Collaborator) left a comment

Temporary solution IMO, but much needed. Thanks, let's merge!
