Skip to content
11 changes: 7 additions & 4 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -799,9 +799,10 @@ def _load_state_dict_into_meta_model(
for submodule in model.modules():
full_tp_plan.update(getattr(submodule, "_tp_plan", {}))

is_safetensors_shard = shard_file.endswith(".safetensors")
file_pointer = None
bin_state_dict = None
if shard_file.endswith(".safetensors"):
if is_safetensors_shard:
file_pointer = safe_open(shard_file, framework="pt", device=tensor_device)
else:
map_location = "cpu"
Expand Down Expand Up @@ -831,7 +832,7 @@ def _load_state_dict_into_meta_model(
# we need to use serialized_param_name as file pointer is untouched
param = (
file_pointer.get_slice(serialized_param_name)
if shard_file.endswith(".safetensors")
if is_safetensors_shard
else bin_state_dict[serialized_param_name]
)

Expand Down Expand Up @@ -900,13 +901,15 @@ def _load_state_dict_into_meta_model(
output_fn = partial(tp_layer._prepare_output_fn, tp_layer.output_layouts, tp_layer.use_local_output)
distribute_module(module_to_tp, device_mesh, None, input_fn, output_fn)
else:
param = param[:]
if is_safetensors_shard:
param = param[:]
if old_param is not None and old_param.is_contiguous():
param = param.contiguous()
module_to_tp.load_state_dict({param_type: param}, strict=False, assign=True)

else:
param = param[:]
if is_safetensors_shard:
param = param[:]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also can't slice 0-dim tensors. See CI from diffusers with MusicLDM and AudioLDM2 which use ClapModel from transformers. https://github.com/huggingface/diffusers/actions/runs/13714387592/job/38359634948?pr=10997

model = ClapModel(
  (text_model): ClapTextModel(
    (embeddings): ClapTextEmbeddings(
      (word_embeddings): Embedding(100...eatures=16, bias=True)
    (activation): ReLU()
    (linear2): Linear(in_features=16, out_features=16, bias=True)
  )
)
state_dict = {'audio_model.audio_encoder.batch_norm.bias': tensor(..., device='meta', size=(8,)), 'audio_model.audio_encoder.batch_...ice='meta', size=(8,)), 'audio_model.audio_encoder.batch_norm.running_var': tensor(..., device='meta', size=(8,)), ...}
...
>               param = param[:]
E               IndexError: slice() cannot be applied to a 0-dim tensor.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you launch the tests for this specific PR?
For safetensors files, `param` is a PySlice object; this is why we need to do `param = param[:]`:
param = file_pointer.get_slice(serialized_param_name)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okay it's the same as here #36372 (comment)

if param_casting_dtype is not None:
param = param.to(param_casting_dtype)
if old_param is not None and old_param.is_contiguous():
Expand Down