Skip to content

Commit 5c56bfc

Browse files
committed
Use public API instead of removed private function
* replaced use of _load_state_dict_into_model with model.load_state_dict because the private function _load_state_dict_into_model was removed in huggingface/transformers#36335 Signed-off-by: Jan Bielak <[email protected]>
1 parent 05f3b57 commit 5c56bfc

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

docs/examples/te_llama/te_llama.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
LlamaRMSNorm,
2020
LlamaConfig,
2121
)
22-
from transformers.modeling_utils import _add_variant, load_state_dict, _load_state_dict_into_model
22+
from transformers.modeling_utils import _add_variant, load_state_dict
2323
from transformers.utils import WEIGHTS_INDEX_NAME
2424
from transformers.utils.hub import get_checkpoint_shard_files
2525

@@ -148,8 +148,8 @@ def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, **k
148148
state_dict = load_state_dict(shard_file)
149149
# replace_params copies parameters relevant only to TransformerEngine
150150
replace_params(state_dict, vanilla_model.state_dict(), config)
151-
# _load_state_dict_into_model copies parameters other than those in TransformerEngine
152-
_load_state_dict_into_model(vanilla_model, state_dict, start_prefix="")
151+
# load_state_dict copies parameters other than those in TransformerEngine
152+
vanilla_model.load_state_dict(state_dict, strict=False)
153153

154154
# Force mem release. Taken from huggingface code
155155
del state_dict

0 commit comments

Comments
 (0)