fix(conversion): Fix size mismatch error during TF->PT model loading #38014
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
What does this PR do?
Loading a PyTorch model from a saved TensorFlow checkpoint using
from_pretrained(..., from_tf=True)could fail with aRuntimeError: size mismatch. The error indicated that weights likeposition embeddings were expected to have the shape of word embeddings
(e.g., [vocab_size, hidden_size]).
This issue was triggered by recent changes that defaulted to initializing
the PyTorch model with meta tensors (
init_empty_weights) during thisconversion process.
The root cause was in the tied weight handling logic within
load_tf2_state_dict_in_pytorch_modelinmodeling_tf_pytorch_utils.py.Multiple distinct parameters initialized as meta tensors can share the same
data_ptr() == 0. The existing logic incorrectly identified these as tiedweights and reused the tensor loaded for the first parameter encountered
with
data_ptr() == 0(often the word embeddings) for subsequent parametersthat also had
data_ptr() == 0.This fix modifies the tied weight check to explicitly skip cases where
pt_weight.data_ptr() == 0, preventing the incorrect reuse of tensorsfor distinct meta parameters and resolving the size mismatch error.
Includes a unit test in
test_modeling_utils.pyto specifically verifythis scenario using
from_pretrained(..., from_tf=True)with meta initialization.Fixes #37786
Who can review?
@Rocketknight1 @gante