
Custom nn.Parameter initialization in PreTrainedModel subclasses is overwritten by post_init()/from_pretrained() causing NaNs/Zeros #42418

@Noietch

Description

System Info

  • transformers version: 4.57.1
  • Platform: Linux-4.18.0-147.mt20200626.413.el8_1.x86_64-x86_64-with-glibc2.35
  • Python version: 3.10.14
  • Huggingface_hub version: 0.35.3
  • Safetensors version: 0.6.2
  • Accelerate version: 1.11.0
  • Accelerate config: not found
  • DeepSpeed version: 0.18.2
  • PyTorch version (accelerator?): 2.7.1+cu126 (CUDA)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: No
  • Using GPU in script?: No
  • GPU type: NVIDIA A100-SXM4-80GB

Who can help?

@Cyrilvallez @zucchini-nlp

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

import numpy as np
import os
import random
import torch
import torch.nn as nn

from transformers import Qwen3VLForConditionalGeneration


def seed_everything(seed):
    """Seed all RNG sources (Python, NumPy, PyTorch CPU/CUDA) for a deterministic run."""
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed_everything(66)


class TestModel1(Qwen3VLForConditionalGeneration):
    """Subclassing: custom learnable parameters are added inside __init__."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.action_head = nn.Linear(1024, 7)
        self.positional_embedding = nn.Parameter(torch.randn(16, 1152))
        # post_init() recursively re-runs _init_weights, silently overwriting
        # the torch.randn initialization above.
        self.post_init()


class TestModel2(nn.Module):
    """Composition: the pretrained model is wrapped, so post_init() /
    from_pretrained() never touches the custom parameters."""

    def __init__(self, *args, model_path, **kwargs):
        super().__init__(*args, **kwargs)
        self.model = Qwen3VLForConditionalGeneration.from_pretrained(model_path)
        self.action_head = nn.Linear(1024, 7)
        self.positional_embedding = nn.Parameter(torch.randn(16, 1152))


test_model1 = TestModel1.from_pretrained("Qwen/Qwen3-VL-4B-Instruct")
test_model2 = TestModel2(model_path="Qwen/Qwen3-VL-4B-Instruct")

# Subclass: the custom parameter comes back as NaNs/zeros.
print(test_model1.positional_embedding)
print(test_model1.positional_embedding.mean(), test_model1.positional_embedding.std())

# Wrapper: the custom parameter keeps its torch.randn values (mean ~0, std ~1).
print(test_model2.positional_embedding)
print(test_model2.positional_embedding.mean(), test_model2.positional_embedding.std())

Expected behavior

When subclassing a model that inherits from PreTrainedModel (e.g. Qwen3VLForConditionalGeneration or LlamaForCausalLM) to add custom learnable parameters, the user-defined initialization in __init__ is silently overwritten.

This occurs because from_pretrained (or the post_init() call at the end of __init__) recursively applies _init_weights to every submodule. This mechanism re-initializes all parameters, ignoring the explicit initialization code the user provided in __init__.
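For illustration, a minimal sketch of one possible workaround: override _init_weights in the subclass so it leaves the custom members alone. The class name GuardedModel is hypothetical; the sketch assumes _init_weights is invoked once per submodule (including the root module, which owns the raw nn.Parameter), and the getattr/hasattr guards are needed because the base __init__ already triggers one initialization pass before the custom attributes exist. This leans on private API and may break across transformers versions:

class GuardedModel(Qwen3VLForConditionalGeneration):
    def __init__(self, config):
        super().__init__(config)
        self.action_head = nn.Linear(1024, 7)
        self.positional_embedding = nn.Parameter(torch.randn(16, 1152))
        self.post_init()

    def _init_weights(self, module):
        # Skip the custom head entirely so its default PyTorch init survives.
        if module is getattr(self, "action_head", None):
            return
        super()._init_weights(module)
        # The raw nn.Parameter lives on the root module; re-draw it after the
        # parent logic has run over `self` (assumption: the root is visited).
        if module is self and hasattr(self, "positional_embedding"):
            with torch.no_grad():
                self.positional_embedding.normal_()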

In the specific case of Qwen3-VL (and potentially other models), this re-initialization produces NaNs or zeros, leaving the model unusable without manual intervention.
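Another partial workaround relies on the private _is_hf_initialized flag, which recent transformers versions check in _initialize_weights before re-initializing a module. Treat this as an assumption about internals rather than a documented API; note it only protects nn.Module children, not a raw nn.Parameter attached to the root model:

class FlaggedModel(Qwen3VLForConditionalGeneration):
    def __init__(self, config):
        super().__init__(config)
        self.action_head = nn.Linear(1024, 7)
        with torch.no_grad():
            nn.init.normal_(self.action_head.weight, std=0.02)
            nn.init.zeros_(self.action_head.bias)
        # Private flag (assumption): _initialize_weights skips modules that
        # already carry _is_hf_initialized = True.
        self.action_head._is_hf_initialized = True
        self.post_init()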

A note on the reproduction script above: torch.randn is used for the custom parameter initialization. While torch.randn samples from a standard normal distribution and therefore does not guarantee a sample mean of exactly 0 or a std of exactly 1, it always produces finite float values, so the observed NaNs/zeros confirm that the user-provided initialization is discarded and replaced by faulty internal initialization logic.
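As an interim measure rather than a fix for the underlying behavior, one can also load first and re-apply the custom initialization by hand afterwards. This sketch reuses TestModel1 and the shapes from the reproduction script:

model = TestModel1.from_pretrained("Qwen/Qwen3-VL-4B-Instruct")
with torch.no_grad():
    # Re-draw the clobbered parameter and restore the head's default init.
    model.positional_embedding.copy_(torch.randn(16, 1152))
    model.action_head.reset_parameters()
assert torch.isfinite(model.positional_embedding).all()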

Labels

Feature request · Usage · bug