
Conversation

@S1ro1 (Contributor) commented Mar 7, 2025

What does this PR do?

Fixes #34928 (partially); the underlying issue also references many other issues with the same problem.

What's the issue

We have many checks that unset use_cache in case of gradient checkpointing, but none that do so for activation checkpointing.

This becomes an issue because we do not really have a way to detect that the model uses AC (to the best of my knowledge). That's why we can only partially mitigate this when training with Accelerate, which sets the FSDP_ACTIVATION_CHECKPOINTING flag when AC is enabled. To the best of my knowledge, other ways of applying AC do not create any side effects that would let us detect it.

Why in PreTrainedModel.__init__ and not lower?

It doesn't really matter whether the model is created via from_pretrained, from_config, or its constructor; this issue affects all of them. It also affects any model that uses DynamicCache, i.e. basically any generative model, so we can't add the warning in every individual place.
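
For illustration, here is a minimal sketch of the kind of check described above, placed at model construction time. The helper name and warning text are hypothetical; FSDP_ACTIVATION_CHECKPOINTING is the flag Accelerate sets, as mentioned earlier.

import logging
import os

logger = logging.getLogger(__name__)


def warn_if_ac_conflicts_with_cache(config):
    # Hypothetical helper: Accelerate exposes FSDP activation checkpointing through the
    # FSDP_ACTIVATION_CHECKPOINTING environment variable, which is the only signal
    # available before the forward pass actually runs.
    ac_enabled = os.environ.get("FSDP_ACTIVATION_CHECKPOINTING", "false").lower() == "true"
    if ac_enabled and getattr(config, "use_cache", False):
        logger.warning(
            "use_cache=True is incompatible with activation checkpointing; "
            "consider setting config.use_cache = False to avoid recomputation mismatches."
        )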

Minimal Repro

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
)
from datasets import load_dataset


def train():
    model_name = "qwen/Qwen2.5-1.5B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_name)
    eli5 = load_dataset("eli5_category", split="train[:100]")
    eli5 = eli5.train_test_split(test_size=0.2)

    def preprocess_function(examples):
        # Join each example's answer texts into a single string before tokenizing.
        return tokenizer(
            [" ".join(x) for x in [ex["text"] for ex in examples["answers"]]]
        )

    tokenized_eli5 = eli5.map(
        preprocess_function,
        batched=True,
        num_proc=4,
        remove_columns=eli5["train"].column_names,
    )
    block_size = 128

    def group_texts(examples):
        # Concatenate all token lists, then split them into fixed-size blocks for causal LM training.
        concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        if total_length >= block_size:
            total_length = (total_length // block_size) * block_size
        result = {
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }
        result["labels"] = result["input_ids"].copy()
        return result

    lm_dataset = tokenized_eli5.map(group_texts, batched=True, num_proc=4)
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    # model.config.use_cache = False  # uncommenting this line works around the failure described below
    training_args = TrainingArguments(
        learning_rate=2e-5,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        num_train_epochs=3,
        report_to="none",
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=lm_dataset["train"],
        eval_dataset=lm_dataset["test"],
        data_collator=data_collator,
        processing_class=tokenizer,
    )

    trainer.train()


if __name__ == "__main__":
    train()
Run with:

$ accelerate launch --use-fsdp --fsdp-activation-checkpointing=true main.py

Before this change, the run would fail with:

saved metadata: ...
recomputed metadata: ...

This is because Qwen is one of the models that sets use_cache=True in its config.
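
For anyone hitting this before the fix lands, the manual workaround, already hinted at by the commented-out line in the repro, is to disable the cache explicitly (a sketch; the use_cache kwarg is forwarded to the model config):

model = AutoModelForCausalLM.from_pretrained(model_name, use_cache=False)
# or, after loading:
model.config.use_cache = False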

cc @muellerzr

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

@github-actions github-actions bot marked this pull request as draft March 7, 2025 18:11
@github-actions bot (Contributor) commented Mar 7, 2025

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. When it is ready for review, please click the Ready for review button (at the bottom of the PR page).

@S1ro1 S1ro1 marked this pull request as ready for review March 7, 2025 18:11
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@S1ro1 (Contributor, Author) commented Mar 8, 2025

Note: the failing test seems to be unrelated and is failing elsewhere as well.
Note 2: figuring out a way to detect AC even outside of Accelerate would be nice, as this seems to be a common issue.

@Rocketknight1 (Member) commented:

This looks like an interaction between generation caching and activation checkpointing, so cc @muellerzr @gante.

@S1ro1 (Contributor, Author) commented Mar 10, 2025

Note 2: figuring out a way to detect AC even outside of Accelerate would be nice, as this seems to be a common issue.

From the PyTorch folks: you can inspect the module structure if someone uses the module-wrapper version of AC (CheckpointWrapper), but otherwise (e.g. if the user just modifies their forward code) there is indeed no easy way to know that AC will be applied without actually running the code.
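
A rough sketch of that module-structure inspection (note that the CheckpointWrapper import path is private and may differ across PyTorch versions):

import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import CheckpointWrapper


def uses_wrapper_based_ac(model: nn.Module) -> bool:
    # Only detects the module-wrapper flavour of activation checkpointing;
    # direct torch.utils.checkpoint.checkpoint() calls inside forward() leave
    # no trace in the module tree, as noted above.
    return any(isinstance(m, CheckpointWrapper) for m in model.modules())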

@ArthurZucker (Collaborator) left a comment:


Feel free to merge @muellerzr if it works for you!

@ArthurZucker ArthurZucker requested review from Rocketknight1 and muellerzr and removed request for Rocketknight1 March 20, 2025 10:22
@muellerzr (Contributor) left a comment:


Yes indeed :)

@muellerzr (Contributor) commented:

@S1ro1 can you resolve the conflicts, and then @ArthurZucker can likely merge (since the failures are unrelated)?

@S1ro1 S1ro1 force-pushed the fix/ac_use_cache branch from 7cab772 to 0ebf27c Compare March 20, 2025 21:43
@S1ro1 (Contributor, Author) commented Mar 20, 2025

@ArthurZucker Should be good to merge.

EDIT: The use_cache test is failing right now; weirdly enough, it passed before, meaning AC with use_cache works (at least sometimes). I need to check whether some of the commits made after this initial draft fixed the underlying issue.

Until then, let's not merge.

@S1ro1 (Contributor, Author) commented Mar 20, 2025

The underlying issue with the cache and activation checkpointing has been fixed as a side effect of #36311, so there is no need to merge this. Closing now.

@S1ro1 S1ro1 closed this Mar 20, 2025


Development

Successfully merging this pull request may close these issues.

Recomputed tensor size does not match when using activation checkpointing when using FSDP and accelerate

5 participants