
Conversation

@SunMarc SunMarc (Member) commented Feb 27, 2025

What does this PR do?

This PR fixes a couple of issues introduced by #36335. Here's the list:

  • We can no longer load `.bin` files when the model is on the meta device, which also breaks quantization
  • Issue when fetching submodules (affects diffusers and peft)
  • Memory allocation issue (cc @gante )
  • Disk offload issue seen in the bnb tests
  • Need to guard the torch import: DTensor support requires torch>=2.5.1 #36472 (a rough sketch of such a guard is shown right after this list)
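
On the last point, here is a minimal sketch of the kind of version guard meant, for illustration only; the exact code in the PR may differ, and the DTensor import path is an assumption based on recent torch releases:

import torch
from packaging import version

# DTensor only exists in sufficiently recent torch releases, so guard the import.
if version.parse(torch.__version__) >= version.parse("2.5.1"):
    from torch.distributed.tensor import DTensor  # assumed import path
else:
    DTensor = None  # callers must check for None before touching DTensor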

Issues remaining for follow-up PRs

  • Maybe check how we can deal with renamed keys better; it's a big mess right now (cc @Cyrilvallez with your refactor PR)
  • DeepSpeed issue
  • Probably more, since a lot of code was modified

To reproduce the errors coming from the peft CI:

from transformers import AutoModelForCausalLM

model_ids = [
    "facebook/opt-125m",
    "facebook/opt-350m",
    "facebook/opt-6.7b",
]
device_maps = [None, 0, "auto"]

# Load each OPT checkpoint with each device_map and access a submodule weight;
# the buggy behavior showed up as an AttributeError when fetching the submodule.
for device_map in device_maps:
    for model_id in model_ids:
        try:
            model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device_map)
            print(model.model.decoder.embed_tokens.weight)
            print(f"Model {model_id} with device_map {device_map} loaded successfully")
        except AttributeError as e:
            print(f"Model {model_id} with device_map {device_map} failed to load with error: {e}")

To reproduce the allocation issue:

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

DEVICE = "cuda"
MODEL_ID = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"

# Track peak GPU memory before loading, after loading, and after a forward pass.
torch.cuda.synchronize()
max_memory = torch.cuda.max_memory_allocated(DEVICE) * 1e-6
print("Before loading -- Max memory (MB): ", max_memory)
torch.cuda.reset_peak_memory_stats(DEVICE)


model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto", torch_dtype=torch.float16)
torch.cuda.synchronize()
max_memory = torch.cuda.max_memory_allocated(DEVICE) * 1e-6
print("After loading -- Max memory (MB): ", max_memory)
torch.cuda.reset_peak_memory_stats(DEVICE)


tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
inputs = tokenizer(["The quick brown"], return_tensors="pt").to(model.device)
_ = model(**inputs)
torch.cuda.synchronize()
max_memory = torch.cuda.max_memory_allocated(DEVICE) * 1e-6
print("After forward -- Max memory (MB): ", max_memory)
torch.cuda.reset_peak_memory_stats(DEVICE)

@github-actions github-actions bot marked this pull request as draft February 27, 2025 14:28
@github-actions (Contributor)

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. When it is ready for review, please click the Ready for review button (at the bottom of the PR page).

@SunMarc SunMarc changed the title from "Fix meta loading" to "Fix couples of issues from #36335" Feb 27, 2025
@SunMarc SunMarc marked this pull request as ready for review February 27, 2025 14:32
@SunMarc SunMarc requested a review from ArthurZucker February 27, 2025 14:32
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@gante gante (Contributor) commented Feb 28, 2025

@SunMarc if possible, add tests to prevent regressions 🙏

@SunMarc SunMarc (Member, Author) commented Feb 28, 2025

@SunMarc if possible, add tests to prevent regressions 🙏

I think we already had a lot of failing tests due to that PR actually, not sure how we missed them @ydshieh @muellerzr. But happy to add more fast tests if that is what's missing (a rough sketch of such a test is shown below).
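
For illustration, a hypothetical fast test along these lines could look like the sketch below; the tiny test checkpoint and the exact assertion are assumptions, not code from this PR, and device_map="auto" assumes accelerate is installed:

import pytest
from transformers import AutoModelForCausalLM

@pytest.mark.parametrize("device_map", [None, "auto"])
def test_submodule_weights_are_materialized(device_map):
    # hf-internal-testing/tiny-random-OPTForCausalLM is assumed as a small test checkpoint
    model = AutoModelForCausalLM.from_pretrained(
        "hf-internal-testing/tiny-random-OPTForCausalLM", device_map=device_map
    )
    # the regression surfaced as weights left on the meta device / an AttributeError on access
    weight = model.model.decoder.embed_tokens.weight
    assert weight.device.type != "meta"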

@SunMarc
Copy link
Member Author

SunMarc commented Feb 28, 2025

The failing tests are not related to this PR, but I found out that they were also caused by #36335. Need to fix.

@ArthurZucker ArthurZucker (Collaborator) left a comment

Thanks!

@SunMarc SunMarc (Member, Author) commented Feb 28, 2025

I don't think the CI will pass, so can you merge it @ArthurZucker?

@ArthurZucker ArthurZucker merged commit a40f1ac into main Mar 1, 2025
20 of 24 checks passed
@ArthurZucker ArthurZucker deleted the fix-meta-loading branch March 1, 2025 06:12
garrett361 pushed a commit to garrett361/transformers that referenced this pull request Mar 4, 2025
* fix

* style

* better allocation

* fix

* fix

* style

* revert disk

* exit

* style

* return if nothing to cache

* dtensor guard

* fix regressiion

* fix regression

* fix

* fix