
Conversation

@gante
Contributor

@gante gante commented Feb 15, 2025

What does this PR do?

(see title)

@gante gante requested a review from ArthurZucker February 15, 2025 10:24
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@gante gante requested review from zucchini-nlp and removed request for ArthurZucker February 15, 2025 12:12
Member

@zucchini-nlp zucchini-nlp left a comment

Thanks a lot! I believe the v4.49 deprecation can also be removed, since we've (almost) released it?

@gante gante merged commit dad513e into huggingface:main Feb 17, 2025
25 checks passed
@gante gante deleted the deprecations_v447_cache branch February 17, 2025 13:55
@BenjaminBossan
Member

BenjaminBossan commented Feb 18, 2025

I'm running into an issue after this change. Here is a minimal reproducer:

from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from datasets import load_dataset

model_id = "hf-internal-testing/tiny-random-MistralForCausalLM"
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token

def process(samples):
    tokenized = tokenizer(samples["quote"], truncation=True, max_length=128)
    return tokenized

data = load_dataset("ybelkada/english_quotes_copy")
data = data.map(process, batched=True)

trainer = Trainer(
    model=model,
    train_dataset=data["train"],
    args=TrainingArguments(
        num_train_epochs=1,
        max_steps=5,
        per_device_train_batch_size=4,
        output_dir="/tmp/mistral"
    ),
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
# Fails with 2+ GPUs (DataParallel); runs fine on a single GPU.
trainer.train()

The error I get is:

Traceback (most recent call last):
  File "/home/name/work/forks/peft/foo.py", line 35, in <module>
    trainer.train()
  File "/home/name/work/forks/transformers/src/transformers/trainer.py", line 2241, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/name/work/forks/transformers/src/transformers/trainer.py", line 2548, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/name/work/forks/transformers/src/transformers/trainer.py", line 3698, in training_step
    loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/name/work/forks/transformers/src/transformers/trainer.py", line 3759, in compute_loss
    outputs = model(**inputs)
              ^^^^^^^^^^^^^^^
  File "/home/name/anaconda3/envs/peft/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/name/anaconda3/envs/peft/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/name/anaconda3/envs/peft/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py", line 194, in forward
    return self.gather(outputs, self.output_device)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/name/anaconda3/envs/peft/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py", line 217, in gather
    return gather(outputs, output_device, dim=self.dim)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/name/anaconda3/envs/peft/lib/python3.11/site-packages/torch/nn/parallel/scatter_gather.py", line 135, in gather
    res = gather_map(outputs)
          ^^^^^^^^^^^^^^^^^^^
  File "/home/name/anaconda3/envs/peft/lib/python3.11/site-packages/torch/nn/parallel/scatter_gather.py", line 127, in gather_map
    return type(out)((k, gather_map([d[k] for d in outputs])) for k in out)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<string>", line 8, in __init__
  File "/home/name/work/forks/transformers/src/transformers/utils/generic.py", line 392, in __post_init__
    for idx, element in enumerate(iterator):
  File "/home/name/anaconda3/envs/peft/lib/python3.11/site-packages/torch/nn/parallel/scatter_gather.py", line 127, in <genexpr>
    return type(out)((k, gather_map([d[k] for d in outputs])) for k in out)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/name/anaconda3/envs/peft/lib/python3.11/site-packages/torch/nn/parallel/scatter_gather.py", line 130, in gather_map
    return type(out)(map(gather_map, zip(*outputs)))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: DynamicCache.__init__() takes 1 positional argument but 2 were given

Re-adding the num_hidden_layers argument to DynamicCache.__init__() (without doing anything with that argument) resolves the error. Running on a single GPU (thus avoiding data parallel) also resolves the error.

I know that the argument has been deprecated, but given the reproducer, I'm not sure how I would even account for that on my end. Is the checkpoint outdated, or is there something else I can do?

@BenjaminBossan
Member

ping @gante

@gante
Contributor Author

gante commented Feb 20, 2025

Writing some findings:

When we compute the loss, we return a ModelOutput. ModelOutput has a __post_init__ method that, when the first field is an iterator, unpacks its contents into the other fields. This doesn't happen with 1 GPU, but it does with 2+ GPUs: the packing comes from torch's data parallel, for reasons I'm not aware of.

When we iterate over the fields while unpacking, i.e. for idx, element in enumerate(iterator), torch's data parallel replaces the output with type(next iterator item)(args). For some reason, the two tensors containing the cache data are passed in these args 👀

Placing a breakpoint inside DynamicCache.__init__, we see that a <class 'map'> instance is passed in, and calling next() on that instance yields the tensor data.

At first glance, and given my very limited knowledge of torch's parallel frameworks, it seems like it expects the model to return a @dataclass whose fields satisfy specific type or class requirements, and DynamicCache does not satisfy them.

The fact that this worked before was luck: the old argument in DynamicCache's __init__ was absorbing the arg that gets passed.
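
To make the mechanism concrete, here is a minimal sketch of that fallback path (a simplification of torch.nn.parallel.scatter_gather.gather_map, runnable without multiple GPUs; the tensor shapes and the identity lambda are just placeholders, not the real gather logic):

import torch
from transformers import DynamicCache

# Pretend these are the caches returned by two DataParallel replicas (one layer each).
cache_a, cache_b = DynamicCache(), DynamicCache()
cache_a.update(torch.ones(1, 1, 2, 4), torch.ones(1, 1, 2, 4), layer_idx=0)
cache_b.update(torch.ones(1, 1, 2, 4), torch.ones(1, 1, 2, 4), layer_idx=0)
outputs = [cache_a, cache_b]

# gather_map's fallback branch for non-tensor, non-dict outputs is roughly
#     type(out)(map(gather_map, zip(*outputs)))
# i.e. it re-instantiates the original type with a single positional argument (a map object).
try:
    type(outputs[0])(map(lambda per_layer: per_layer, zip(*outputs)))
except TypeError as exc:
    # On the code from this PR: DynamicCache.__init__() takes 1 positional argument but 2 were given
    print(exc)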

zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request Feb 21, 2025
@S1ro1
Contributor

S1ro1 commented Feb 24, 2025

The issue is that Gather receives a list of ModelOutput instances and treats each of them as a dict, building a single dict (of type ModelOutput) whose values are gathered together recursively. However, when the * operator is applied to a value of type DynamicCache, which implements __iter__, it yields tuples of past keys and values (k, v).

In this scenario:

outputs = [DynamicCache(...), DynamicCache(...)]
y = list(zip(*outputs)) # PyTorch does this in Gather
# y = [(k1, v1), (k2, v2)]

Because PyTorch casts the result of each Gather call back to its original type (here DynamicCache) by calling it with the result of y, the gathered k, v get passed into the constructor of DynamicCache, resulting in the error above.

I'm not totally familiar with the API we want to pursue, but Gather on the cache currently doesn't really work: the resulting Cache after applying Gather doesn't make much sense, since the keys/values aren't actually gathered.

I only see one option to fix this without touching torch: an extra parameter in __init__, such as the one below. Whether we use it depends on whether gathering the cache is what we'd like to do; if not, it can just be an unused parameter and a no-op, as before.

    def __init__(self, cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None) -> None:
        super().__init__()
        ...
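
For illustration, a rough sketch of the "actually gather the cache" flavor (a hypothetical subclass, not the real implementation), where the constructor consumes the per-layer (key, value) pairs it is handed:

from typing import Iterable, Optional, Tuple

import torch
from transformers import DynamicCache


class GatherAwareDynamicCache(DynamicCache):  # hypothetical, for illustration only
    def __init__(self, gathered: Optional[Iterable[Tuple[torch.Tensor, torch.Tensor]]] = None) -> None:
        super().__init__()
        if gathered is not None:
            # Populate the cache from the per-layer (key, value) pairs that
            # torch's gather passes positionally when re-instantiating the class.
            for layer_idx, (key_states, value_states) in enumerate(gathered):
                self.update(key_states, value_states, layer_idx)

The no-op variant would simply accept the parameter and ignore it, as in the snippet above.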
