
Timm models Safetensor weights give 'NoneType' object has no attribute 'get', weight re-initialization and wrong num_labels #25282

@sawradip


System Info

My env information:

- `transformers` version: 4.31.0
- Platform: Linux-5.15.0-78-generic-x86_64-with-glibc2.31
- Python version: 3.9.17
- Huggingface_hub version: 0.16.4
- Safetensors version: 0.3.1
- Accelerate version: 0.20.3
- Accelerate config:    not found
- PyTorch version (GPU?): 2.0.1+cu117 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: <fill in>
- Using distributed or parallel set-up in script?: <fill in>

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

For a GSoC project under the OpenVINO Toolkit, I have been working with timm models through Transformers.

As we know, most of the timm models (on the HF Hub) are trained or fine-tuned on some variation of the ImageNet dataset, and are thus effectively image classification models. If I attempt to load a timm model using AutoModelForImageClassification,

import torch
from transformers import AutoModelForImageClassification

model_id = "timm/vit_tiny_r_s16_p8_224.augreg_in21k"

hf_model = AutoModelForImageClassification.from_pretrained(model_id)

out = hf_model(pixel_values = torch.zeros((5, 3, hf_model.config.image_size, hf_model.config.image_size)))
print(out.logits.shape)

I get this error:

Traceback (most recent call last):
  File "/home/sawradip/Desktop/practice_code/practice_gsoc/optimum-intel/../demo.py", line 10, in <module>
    hf_model = AutoModelForImageClassification.from_pretrained( model_id,
  File "/home/sawradip/miniconda3/envs/gsoc_env/lib/python3.9/site-packages/transformers/models/auto/auto_factory.py", line 493, in from_pretrained
    return model_class.from_pretrained(
  File "/home/sawradip/miniconda3/envs/gsoc_env/lib/python3.9/site-packages/transformers/modeling_utils.py", line 2629, in from_pretrained
    state_dict = load_state_dict(resolved_archive_file)
  File "/home/sawradip/miniconda3/envs/gsoc_env/lib/python3.9/site-packages/transformers/modeling_utils.py", line 449, in load_state_dict
    if metadata.get("format") not in ["pt", "tf", "flax"]:
AttributeError: 'NoneType' object has no attribute 'get'
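The traceback suggests that the metadata read from the .safetensors file was None. As a minimal sketch, the optional metadata can be inspected with only the standard library, relying on the safetensors layout of an 8-byte little-endian header length followed by a JSON header whose optional `__metadata__` key holds the metadata. The helper names `read_safetensors_metadata` and `checkpoint_format` are hypothetical, and the None-safe lookup is only an illustration of a possible guard, not the actual transformers code.

```python
import json
import struct


def read_safetensors_metadata(path):
    """Return the optional ``__metadata__`` mapping of a .safetensors file, or None."""
    with open(path, "rb") as f:
        # The first 8 bytes hold the JSON header length (little-endian uint64).
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len).decode("utf-8"))
    # ``__metadata__`` is optional, so this can legitimately be None.
    return header.get("__metadata__")


def checkpoint_format(metadata):
    """Hypothetical None-safe lookup (an assumption, not the library's code)."""
    return (metadata or {}).get("format")
```

Calling `read_safetensors_metadata` on a locally downloaded copy of the checkpoint would show whether the file carries a `format` entry at all.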

I find that this issue doesn't occur if I force transformers to use the PyTorch weights and avoid .safetensors.

import torch
from transformers import AutoModelForImageClassification

model_id = "timm/vit_tiny_r_s16_p8_224.augreg_in21k"

hf_model = AutoModelForImageClassification.from_pretrained(
    model_id,
    use_safetensors=False,
)

out = hf_model(pixel_values = torch.zeros((5, 3, hf_model.config.image_size, hf_model.config.image_size)))
print(out.logits.shape)

But I still get this warning in the output, saying that a lot of weights were not initialized from the checkpoint:

Some weights of ViTForImageClassification were not initialized from the model checkpoint at timm/vit_tiny_r_s16_p8_224.augreg_in21k and are newly initialized: ['encoder.layer.0.layernorm_before.bias', 'encoder.layer.11.attention.attention.query.weight', 'encoder.layer.1.attention.attention.query.weight', 'encoder.layer.11.attention.output.dense.bias', 'encoder.layer.4.attention.output.dense.bias', 'encoder.layer.4.layernorm_before.bias', 'encoder.layer.10.attention.attention.query.weight', 'encoder.layer.6.attention.attention.key.weight', 'encoder.layer.4.output.dense.bias', 'encoder.layer.0.attention.attention.key.bias', 'encoder.layer.2.layernorm_after.weight', 'encoder.layer.7.attention.output.dense.bias', 'encoder.layer.7.output.dense.weight', 'encoder.layer.10.layernorm_after.bias', 'layernorm.bias', 'encoder.layer.0.attention.attention.key.weight', 'encoder.layer.1.attention.attention.value.bias', 'encoder.layer.4.output.dense.weight', 'embeddings.patch_embeddings.projection.weight', 'encoder.layer.6.attention.output.dense.weight', 'encoder.layer.1.layernorm_after.weight', 'encoder.layer.2.attention.attention.query.weight', 'encoder.layer.3.attention.attention.key.bias', 'encoder.layer.11.layernorm_after.bias', 'encoder.layer.4.attention.output.dense.weight', 'encoder.layer.2.layernorm_before.weight', 'encoder.layer.4.attention.attention.query.bias', 'encoder.layer.6.layernorm_after.weight', 'encoder.layer.4.intermediate.dense.bias', 'encoder.layer.7.layernorm_before.weight', 'encoder.layer.8.attention.attention.value.bias', 'encoder.layer.6.attention.attention.query.weight', 'encoder.layer.8.attention.output.dense.weight', 'encoder.layer.10.layernorm_before.weight', 'encoder.layer.1.intermediate.dense.bias', 'encoder.layer.9.attention.attention.key.weight', 'encoder.layer.6.layernorm_after.bias', 'classifier.bias', 'encoder.layer.1.layernorm_before.bias', 'encoder.layer.6.attention.output.dense.bias', 'encoder.layer.8.intermediate.dense.weight', 
'encoder.layer.2.attention.output.dense.bias', 'encoder.layer.10.attention.output.dense.bias', 'encoder.layer.10.attention.attention.query.bias', 'encoder.layer.3.layernorm_before.bias', 'encoder.layer.3.intermediate.dense.weight', 'encoder.layer.5.attention.attention.value.bias', 'encoder.layer.6.attention.attention.value.weight', 'encoder.layer.0.layernorm_after.weight', 'encoder.layer.10.intermediate.dense.bias', 'encoder.layer.0.output.dense.bias', 'encoder.layer.0.attention.output.dense.bias', 'encoder.layer.7.layernorm_after.weight', 'encoder.layer.8.output.dense.bias', 'layernorm.weight', 'encoder.layer.0.output.dense.weight', 'encoder.layer.11.attention.attention.key.weight', 'encoder.layer.2.attention.attention.query.bias', 'encoder.layer.11.attention.attention.value.weight', 'encoder.layer.3.layernorm_after.bias', 'classifier.weight', 'encoder.layer.4.attention.attention.value.weight', 'encoder.layer.8.layernorm_after.weight', 'encoder.layer.9.attention.attention.query.weight', 'encoder.layer.0.intermediate.dense.bias', 'encoder.layer.8.output.dense.weight', 'encoder.layer.1.attention.attention.value.weight', 'encoder.layer.6.output.dense.weight', 'encoder.layer.6.output.dense.bias', 'encoder.layer.5.attention.attention.query.bias', 'encoder.layer.6.attention.attention.key.bias', 'encoder.layer.9.layernorm_before.bias', 'encoder.layer.7.attention.attention.query.weight', 'encoder.layer.5.output.dense.bias', 'encoder.layer.8.layernorm_after.bias', 'encoder.layer.2.attention.attention.key.weight', 'encoder.layer.5.layernorm_after.bias', 'encoder.layer.10.attention.output.dense.weight', 'encoder.layer.7.layernorm_after.bias', 'encoder.layer.5.intermediate.dense.weight', 'encoder.layer.9.attention.attention.value.bias', 'encoder.layer.3.output.dense.weight', 'encoder.layer.2.attention.attention.value.bias', 'encoder.layer.5.attention.attention.key.weight', 'encoder.layer.6.intermediate.dense.bias', 'encoder.layer.6.attention.attention.query.bias', 
'encoder.layer.9.output.dense.weight', 'encoder.layer.0.attention.attention.value.weight', 'encoder.layer.3.attention.attention.value.bias', 'encoder.layer.2.layernorm_before.bias', 'encoder.layer.2.output.dense.weight', 'encoder.layer.1.output.dense.weight', 'encoder.layer.4.intermediate.dense.weight', 'encoder.layer.5.attention.attention.value.weight', 'encoder.layer.9.intermediate.dense.weight', 'encoder.layer.8.attention.attention.key.weight', 'encoder.layer.3.attention.attention.value.weight', 'encoder.layer.11.intermediate.dense.weight', 'encoder.layer.7.attention.attention.key.weight', 'encoder.layer.0.attention.attention.value.bias', 'encoder.layer.2.attention.attention.value.weight', 'encoder.layer.5.layernorm_before.bias', 'encoder.layer.0.intermediate.dense.weight', 'encoder.layer.5.intermediate.dense.bias', 'encoder.layer.2.intermediate.dense.bias', 'encoder.layer.5.layernorm_before.weight', 'encoder.layer.1.attention.output.dense.weight', 'encoder.layer.7.attention.attention.value.weight', 'encoder.layer.6.layernorm_before.weight', 'encoder.layer.3.attention.attention.key.weight', 'encoder.layer.11.attention.attention.query.bias', 'encoder.layer.5.attention.output.dense.bias', 'encoder.layer.6.layernorm_before.bias', 'encoder.layer.3.attention.output.dense.weight', 'encoder.layer.11.attention.output.dense.weight', 'encoder.layer.9.attention.output.dense.bias', 'encoder.layer.10.attention.attention.value.weight', 'encoder.layer.7.attention.attention.key.bias', 'encoder.layer.10.attention.attention.value.bias', 'encoder.layer.3.attention.output.dense.bias', 'encoder.layer.4.attention.attention.value.bias', 'encoder.layer.0.attention.output.dense.weight', 'encoder.layer.5.attention.output.dense.weight', 'encoder.layer.2.attention.attention.key.bias', 'encoder.layer.3.intermediate.dense.bias', 'encoder.layer.5.output.dense.weight', 'encoder.layer.8.attention.attention.query.weight', 'encoder.layer.3.attention.attention.query.bias', 
'encoder.layer.1.attention.attention.key.weight', 'encoder.layer.4.layernorm_after.weight', 'encoder.layer.7.intermediate.dense.bias', 'encoder.layer.7.attention.attention.value.bias', 'encoder.layer.3.layernorm_before.weight', 'encoder.layer.11.attention.attention.key.bias', 'encoder.layer.10.output.dense.bias', 'encoder.layer.8.intermediate.dense.bias', 'encoder.layer.9.intermediate.dense.bias', 'encoder.layer.11.output.dense.weight', 'encoder.layer.1.attention.output.dense.bias', 'encoder.layer.3.output.dense.bias', 'encoder.layer.4.attention.attention.key.weight', 'encoder.layer.10.attention.attention.key.weight', 'encoder.layer.4.layernorm_before.weight', 'encoder.layer.9.attention.attention.value.weight', 'encoder.layer.5.attention.attention.query.weight', 'encoder.layer.2.output.dense.bias', 'encoder.layer.0.attention.attention.query.weight', 'encoder.layer.10.intermediate.dense.weight', 'encoder.layer.8.attention.attention.value.weight', 'encoder.layer.4.attention.attention.key.bias', 'encoder.layer.4.layernorm_after.bias', 'encoder.layer.6.intermediate.dense.weight', 'encoder.layer.7.intermediate.dense.weight', 'encoder.layer.9.attention.output.dense.weight', 'encoder.layer.11.output.dense.bias', 'encoder.layer.0.layernorm_after.bias', 'encoder.layer.9.attention.attention.query.bias', 'encoder.layer.11.attention.attention.value.bias', 'encoder.layer.8.attention.attention.key.bias', 'encoder.layer.2.attention.output.dense.weight', 'encoder.layer.9.layernorm_after.bias', 'encoder.layer.11.layernorm_after.weight', 'encoder.layer.6.attention.attention.value.bias', 'encoder.layer.2.layernorm_after.bias', 'encoder.layer.9.layernorm_after.weight', 'encoder.layer.1.attention.attention.key.bias', 'encoder.layer.10.output.dense.weight', 'encoder.layer.7.attention.attention.query.bias', 'embeddings.cls_token', 'encoder.layer.2.intermediate.dense.weight', 'encoder.layer.11.layernorm_before.weight', 'encoder.layer.0.attention.attention.query.bias', 
'encoder.layer.1.layernorm_after.bias', 'encoder.layer.3.attention.attention.query.weight', 'encoder.layer.1.output.dense.bias', 'encoder.layer.10.layernorm_after.weight', 'encoder.layer.5.layernorm_after.weight', 'encoder.layer.1.layernorm_before.weight', 'encoder.layer.0.layernorm_before.weight', 'encoder.layer.5.attention.attention.key.bias', 'encoder.layer.8.layernorm_before.weight', 'encoder.layer.3.layernorm_after.weight', 'encoder.layer.10.layernorm_before.bias', 'embeddings.position_embeddings', 'encoder.layer.11.intermediate.dense.bias', 'encoder.layer.7.layernorm_before.bias', 'encoder.layer.1.attention.attention.query.bias', 'encoder.layer.10.attention.attention.key.bias', 'encoder.layer.7.attention.output.dense.weight', 'encoder.layer.9.layernorm_before.weight', 'encoder.layer.1.intermediate.dense.weight', 'encoder.layer.4.attention.attention.query.weight', 'encoder.layer.8.attention.attention.query.bias', 'encoder.layer.7.output.dense.bias', 'encoder.layer.8.layernorm_before.bias', 'encoder.layer.9.output.dense.bias', 'encoder.layer.8.attention.output.dense.bias', 'embeddings.patch_embeddings.projection.bias', 'encoder.layer.11.layernorm_before.bias', 'encoder.layer.9.attention.attention.key.bias']

This means these models cannot be used directly for classification on ImageNet.
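A warning like the one above typically appears when the parameter names the model expects do not match the names stored in the checkpoint. The following is a minimal sketch of that comparison. The expected names are taken from the warning; the checkpoint names are illustrative timm-style names and are an assumption, not read from the actual file. The helper `diff_state_dict_keys` is hypothetical.

```python
def diff_state_dict_keys(expected_keys, checkpoint_keys):
    """Return (missing, unexpected) parameter names as sorted lists."""
    expected, found = set(expected_keys), set(checkpoint_keys)
    missing = sorted(expected - found)      # would be newly initialized
    unexpected = sorted(found - expected)   # present in the file but unused
    return missing, unexpected


# Names copied from the warning above.
expected = [
    "embeddings.cls_token",
    "encoder.layer.0.attention.attention.query.weight",
    "classifier.weight",
]
# ASSUMPTION: illustrative timm-style names, not read from the real checkpoint.
checkpoint = ["cls_token", "blocks.0.attn.qkv.weight", "head.weight"]

missing, unexpected = diff_state_dict_keys(expected, checkpoint)
print("missing:", missing)
print("unexpected:", unexpected)
```

If no name overlaps, every expected parameter ends up in the missing list, which would be consistent with nearly the whole model being reported as newly initialized.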

But I still get an output whose shape indicates 2 output classes, which is not the expected number of classes for this model:

torch.Size([5, 2])

The model name timm/vit_tiny_r_s16_p8_224.augreg_in21k, however, indicates that the weights were fine-tuned for ImageNet-21k, meaning 21843 classes.

This happens because the config files attached to all timm models on the Hub store the number of output classes in the num_classes parameter, whereas AutoConfig expects a num_labels parameter in the config file. Not finding such a parameter, it assigns the default value 2, as can be seen here.

So we can see in the model,

print(hf_model.config.num_classes)
-> 21843
print(hf_model.config.num_labels)
-> 2
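The mismatch can be illustrated with a small stand-alone sketch. The function `resolve_num_labels` is hypothetical and is not part of transformers; it only shows one way the num_classes value could be mapped to num_labels before falling back to the default of 2. The config dictionary is an illustrative assumption modelled on the values printed above.

```python
def resolve_num_labels(config_dict, default=2):
    """Hypothetical helper: prefer num_labels, then num_classes, then the default."""
    if "num_labels" in config_dict:
        return config_dict["num_labels"]
    if "num_classes" in config_dict:
        return config_dict["num_classes"]
    return default


# ASSUMPTION: a timm-style config that only defines num_classes.
timm_style_config = {"architecture": "vit_tiny_r_s16_p8_224", "num_classes": 21843}

# Looking up num_labels alone falls back to the default.
print(timm_style_config.get("num_labels", 2))  # prints 2
# Consulting num_classes as well recovers the intended value.
print(resolve_num_labels(timm_style_config))   # prints 21843
```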

I know there are a number of issues here, but it is not possible to reproduce the later ones without fixing the earlier ones, so creating a separate issue for each would be more cumbersome for the reader.

Let me summarize the points I am making:

  1. timm models cannot be loaded through AutoModelForImageClassification because loading from the safetensors weights fails.
  2. If we explicitly pass use_safetensors = False, the PyTorch weights are loaded, but a huge number of weights are initialized randomly, so the models won't be useful out of the box.
  3. For all models, the number of output classes is 2, and unlike timm's create_model, there is no option for users to specify num_classes without modifying the config file.

Is this behaviour expected?
@amyeroberts @rwightman

Expected behavior

The expected behavior is that the code block above outputs:

torch.Size([5, 21843])
