
Conversation

@kylesayrs
Contributor

@kylesayrs kylesayrs commented Feb 20, 2025

Purpose

  • Reduce runtime when loading missing keys with `low_cpu_mem_usage`
    • This is particularly helpful when loading large models

Changes

  • Do not eagerly convert all keys into a list. Instead, check membership against the dict directly to preserve O(1) lookups within the for loop (a minimal sketch of the pattern follows)
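
A minimal sketch of the pattern (illustrative only, not the actual transformers code; `expected_keys` and `loaded_keys` are hypothetical stand-ins):

    # toy stand-ins for the model's expected keys and the checkpoint's loaded keys
    loaded_keys = {f"layer.{i}.weight" for i in range(100_000)}
    expected_keys = [f"layer.{i}.weight" for i in range(100_000)]

    # before: keys materialized into a list, so each `in` test is a linear scan (O(n) per lookup)
    loaded_list = list(loaded_keys)
    missing = [k for k in expected_keys if k not in loaded_list]

    # after: membership checked directly on the set/dict is a hash lookup (O(1) per lookup)
    missing = [k for k in expected_keys if k not in loaded_keys]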

Testing

  • The script below can be used to load a large model without downloading its weights, for testing
    • Without these changes, loading deepseek_v3 takes 229.1s
    • With these changes, loading deepseek_v3 takes 64.5s (~4x speedup)
load_without_weights_for_testing.py
import os
import torch
import tempfile
import contextlib
from huggingface_hub import snapshot_download
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME
from safetensors.torch import save_file

import time
from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
from accelerate import init_empty_weights

## Define utils

@contextlib.contextmanager
def skip_weights_download(model_class: PreTrainedModel = AutoModelForCausalLM):
    """
    Context manager under which models are initialized without having to download
    the model weight files

    :param model_class: model class whose `from_pretrained` is patched, defaults to `AutoModelForCausalLM`
    """
    original_fn = model_class.from_pretrained
    weights_files = [
        "*.bin", "*.safetensors", "*.pth", SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME
    ]

    @classmethod
    def patched(cls, *args, **kwargs):
        nonlocal tmp_dir

        # intercept model stub
        model_stub = args[0] if args else kwargs.pop("pretrained_model_name_or_path")

        # download files into tmp dir
        os.makedirs(tmp_dir, exist_ok=True)
        snapshot_download(
            repo_id=model_stub,
            local_dir=tmp_dir,
            ignore_patterns=weights_files
        )

        # make an empty weights file to avoid errors
        weights_file_path = os.path.join(tmp_dir, "model.safetensors")
        save_file({}, weights_file_path, metadata={"format": "pt"})

        # load from tmp dir
        return original_fn(tmp_dir, **kwargs)
    
    with tempfile.TemporaryDirectory() as tmp_dir:
        model_class.from_pretrained = patched
        try:
            yield
        finally:
            model_class.from_pretrained = original_fn


@contextlib.contextmanager
def skip_weights_initialize():
    def skip(tensor: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        return tensor

    # save the original initializers so they can be restored on exit
    kaiming_restore = torch.nn.init.kaiming_uniform_
    uniform_restore = torch.nn.init.uniform_
    normal_restore = torch.nn.init.normal_

    t_uniform_restore = torch.Tensor.uniform_
    t_normal_restore = torch.Tensor.normal_

    # replace the initializers with a no-op that returns the tensor unchanged
    torch.nn.init.kaiming_uniform_ = skip
    torch.nn.init.uniform_ = skip
    torch.nn.init.normal_ = skip

    torch.Tensor.uniform_ = skip
    torch.Tensor.normal_ = skip
    try:
        yield
    finally:
        torch.nn.init.kaiming_uniform_ = kaiming_restore
        torch.nn.init.uniform_ = uniform_restore
        torch.nn.init.normal_ = normal_restore

        torch.Tensor.uniform_ = t_uniform_restore
        torch.Tensor.normal_ = t_normal_restore


## Load model
model_name = "deepseek-ai/DeepSeek-V3"

# needed for deepseek: remove the quantization config so the checkpoint loads without quantization support
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
del config.quantization_config

# load model
start = time.time()
with skip_weights_download(), skip_weights_initialize():
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        config=config,
        device_map="auto",
        trust_remote_code=True,
    )
    print(f"Loaded model in {time.time() - start:.1f}s")

Signed-off-by: Kyle Sayers <[email protected]>
@kylesayrs kylesayrs changed the title Reduce runtime when loading missing keys [Modeling] Reduce runtime when loading missing keys Feb 20, 2025
Member

@Rocketknight1 Rocketknight1 left a comment


This seems like an obviously correct change, and a clear improvement!

Approving for now, but I see you're pushing additional commits, so ping me whenever it's ready for final review + merging.

@kylesayrs
Contributor Author

@Rocketknight1 All good on my end, thanks!

@Rocketknight1 Rocketknight1 merged commit 05dfed0 into huggingface:main Feb 24, 2025
21 checks passed
@Rocketknight1
Member

Merged, and thank you for the improvement @kylesayrs!

@kylesayrs kylesayrs deleted the kylesayrs/low-memory-load-optimization branch March 7, 2025 20:40