
Conversation

@remi-or
Collaborator

@remi-or remi-or commented Oct 14, 2025

Currently, DDP is broken when there is a DynamicCache because it has no __iter__ method, so the cache cannot be concatenated after the distributed forward. This PR adds back an __iter__ method and adapts the way DDP cache data is consumed so that sliding windows are properly initialized.
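
For context on why the missing __iter__ matters: after a data-parallel forward, the per-replica outputs (including the cache) have to be gathered and concatenated, which requires iterating over the cache layer by layer. Below is a minimal sketch of the idea, not the actual transformers implementation; the class, attribute, and helper names are illustrative only:

```python
import torch


class MiniDynamicCache:
    """Toy stand-in for DynamicCache: one (key, value) pair per layer."""

    def __init__(self, ddp_cache_data=None):
        self.keys, self.values = [], []
        if ddp_cache_data is not None:
            for key, value in ddp_cache_data:
                self.keys.append(key)
                self.values.append(value)

    def __iter__(self):
        # Yield one (key, value) tuple per layer so that the code gathering
        # replica outputs can iterate over the cache and rebuild it.
        yield from zip(self.keys, self.values)


def merge_replica_caches(caches):
    # Group layer i across all replicas, then concatenate along the batch dim.
    merged = [
        (torch.cat([k for k, _ in layer_group], dim=0),
         torch.cat([v for _, v in layer_group], dim=0))
        for layer_group in zip(*caches)
    ]
    return MiniDynamicCache(merged)


# Example: two "replica" caches with batch size 1 each.
cache_a = MiniDynamicCache([(torch.zeros(1, 8, 4, 64), torch.zeros(1, 8, 4, 64))])
cache_b = MiniDynamicCache([(torch.ones(1, 8, 4, 64), torch.ones(1, 8, 4, 64))])
merged = merge_replica_caches([cache_a, cache_b])
print(merged.keys[0].shape)  # torch.Size([2, 8, 4, 64])
```

Without __iter__, the `zip(*caches)` step above (and anything else that expects to walk the cache layer by layer) fails, which is what breaks the gather after the distributed forward.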

@remi-or remi-or requested a review from Cyrilvallez October 14, 2025 10:48
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Member

@SunMarc SunMarc left a comment


Thanks! Indeed, this PR broke a few things regarding distributed training. I have the same issue with DP. cc @gante, as I saw in a previous PR that you were the one who fixed the compatibility with DP.
Btw, there is a failing test @remi-or that needs to be fixed.

@SunMarc SunMarc requested a review from gante October 14, 2025 13:11
@gante
Contributor

gante commented Oct 14, 2025

indeed this PR broke a few things regarding distributed training.

I remember fixing it (for similar reasons as in this PR), but I don't remember the following: can we instead disable cache at training time? I know there are some less-common fine-tuning strategies that use caches (prefix tuning in PEFT), but it may be wiser to simply disallow cache+DDP

At inference time, if we want scale, we won't be using these classes anyway (we want continuous batching).

@SunMarc
Member

SunMarc commented Oct 14, 2025

but it may be wiser to simply disallow cache+DDP

Indeed, I think for v5 it is reasonable to make this change, and most users won't be impacted by it! cc @BenjaminBossan

@BenjaminBossan
Member

can we instead disable cache at training time?

What exactly would that entail? Would it mean that any PEFT method that uses past_key_values during training would stop working?

@remi-or remi-or merged commit 82cae9e into huggingface:main Oct 14, 2025
25 checks passed
@SunMarc
Member

SunMarc commented Oct 14, 2025

What exactly would that entail? Would it mean that any PEFT method that uses past_key_values during training would stop working?

Most likely, I will add an arg in Trainer to allow users to change use_cache, and the default will be False unless we detect it is a PEFT model.
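
To make the proposal concrete, something along these lines could work; the argument name, the PEFT detection heuristic, and the helper are all hypothetical, not the actual Trainer API (see #41585 for the real proposal):

```python
from dataclasses import dataclass


@dataclass
class MyTrainingArguments:
    # Hypothetical flag: caches disabled during training unless explicitly enabled.
    use_cache: bool = False


def resolve_use_cache(model, args: MyTrainingArguments) -> bool:
    # Sketch of the detection: keep the cache enabled for PEFT models
    # (e.g. prefix tuning relies on past_key_values), otherwise honor the flag.
    is_peft_model = model.__class__.__name__.startswith("Peft")
    return is_peft_model or args.use_cache


# Inside a training setup one would then do something like:
# model.config.use_cache = resolve_use_cache(model, args)
```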

@SunMarc
Member

SunMarc commented Oct 14, 2025

lmk what you think of that: #41585

@Cyrilvallez
Member

Cyrilvallez commented Oct 15, 2025

Yes, I strongly agree with @gante here - I don't really see the point of supporting dp/ddp on the cache, as it's inference-only anyway. If we can avoid having those somewhat awkward __init__, it's a win in my opinion!

@BenjaminBossan
Member

Oh, I just noticed that this PR breaks this line in PEFT:

https://github.com/huggingface/peft/blob/086f187a4d89be3fbfcc3629fd833c0878dde4c0/src/peft/peft_model.py#L796

There, we initialize a DynamicCache instance with ddp_cache_data, but now the expected input data format has changed. Would it be possible to make passing sliding_window_tensor optional in the data format? Otherwise, I guess we'd have to do a version check on transformers in PEFT and then insert a None for sliding_window_tensor into the ddp_cache_data. But other code out there could possibly also be affected by this change, so a backwards-compatible init would be better IMHO.
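
To make the format change concrete, here is a sketch of the two shapes of ddp_cache_data; the tensor shapes and the None placeholder are illustrative, not the exact transformers internals:

```python
import torch

key = torch.zeros(1, 8, 4, 64)    # [batch, num_heads, seq_len, head_dim]
value = torch.zeros(1, 8, 4, 64)

# Format PEFT currently passes: one (key, value) tuple per layer.
old_ddp_cache_data = [(key, value)]

# Format expected after this PR: a leading sliding-window tensor per layer
# (None here, standing in for a layer without a sliding window).
new_ddp_cache_data = [(None, key, value)]

# The version-gated workaround mentioned above would roughly be:
# if the installed transformers is new enough:
#     ddp_cache_data = [(None, k, v) for k, v in ddp_cache_data]
```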

@remi-or
Collaborator Author

remi-or commented Oct 15, 2025

Hey @BenjaminBossan, I am adding support for (key, value) tuples instead of (sliding, key, value) in #41612. Once merged, please let us know if it works!
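
One shape such a backwards-compatible init could take is to accept both per-layer tuple formats; this is purely a sketch of the idea under assumed attribute names, not necessarily what #41612 does:

```python
def consume_ddp_cache_data(cache, ddp_cache_data):
    """Accept both (key, value) and (sliding, key, value) per-layer tuples."""
    for layer in ddp_cache_data:
        if len(layer) == 2:          # legacy format, e.g. from older PEFT code
            key, value = layer
            sliding = None
        else:                        # newer format with a sliding-window tensor
            sliding, key, value = layer
        # Attribute names below are assumptions made for this sketch.
        cache.keys.append(key)
        cache.values.append(value)
        cache.sliding_windows.append(sliding)
```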

ngazagna-qc pushed a commit to ngazagna-qc/transformers that referenced this pull request Oct 23, 2025
* Add __iter__ to DynamicCache

* Fix tests that use ddp init
