Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1897,7 +1897,11 @@ def post_init(self):
if self.base_model is self:
self._pp_plan = self.config.base_model_pp_plan

self._tp_plan = self._tp_plan or self.config.base_model_tp_plan or {}
if not self._tp_plan:
if isinstance(self.config.base_model_tp_plan, dict):
self._tp_plan = self.config.base_model_tp_plan.copy()
else:
self._tp_plan = {}
Copy link
Member

@Cyrilvallez Cyrilvallez Mar 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a heads-up that if a model which is above the base model in the class hierarchy does not have a tp_plan, we are going to wrongly add again the base_model_plan.
It's not the case as of now in the library though. But maybe

if self.base_model is self:
    self._pp_plan = self.config.base_model_pp_plan.copy() if self.config.base_model_pp_plan is not None else None
    self._tp_plan = self.config.base_model_tp_plan.copy() if self.config.base_model_tp_plan is not None else {}
else:
    self._tp_plan = self._tp_plan or {}
    for name, module in self.named_children():
        if plan := getattr(module, "_tp_plan", None):
            self._tp_plan.update({f"{name}.{k}": v for k, v in plan.items()})

is more robust and should still work with composite models

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry i just edited haha - the part with pp is less important as we don't modify the dict later

for name, module in self.named_children():
if plan := getattr(module, "_tp_plan", None):
self._tp_plan.update({f"{name}.{k}": v for k, v in plan.items()})
Expand Down
1 change: 1 addition & 0 deletions tests/generation/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2305,6 +2305,7 @@ def test_generate_methods_with_logits_to_keep(self):
self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist())

@pytest.mark.generate
@is_flaky
def test_assisted_decoding_with_logits_to_keep(self):
for model_class in self.all_generative_model_classes:
if "logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):
Expand Down
1 change: 1 addition & 0 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,6 +803,7 @@ def test_model_init(self):
trainer.train()
self.check_trained_model(trainer.model, alternate_seed=True)

@slow
def test_gradient_accumulation_loss_alignment_with_model_loss(self):
set_seed(42)
import datasets
Expand Down
Loading