Add packed tensor format support for flex/sdpa/eager through the mask! #39194


Merged: 16 commits merged into main on Jul 4, 2025

Conversation

@Cyrilvallez (Member) commented Jul 3, 2025

What does this PR do?

As per the title.

import torch
from transformers import AutoModelForCausalLM
from transformers.masking_utils import create_causal_mask

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B", torch_dtype=torch.float16)


batch_size = 1
sequence_length = 10
cache_position = torch.arange(sequence_length)
position_ids = torch.tensor([[0,1,2,3,0,1,0,1,2,3]])  # This corresponds to 3 packed sequences

attention_mask = create_causal_mask(
    config=model.config,
    # we only need batch size, seq_length and dtype here - we don't care about the values of the embeddings
    input_embeds=torch.empty((batch_size, sequence_length), dtype=model.dtype),
    attention_mask=None,
    cache_position=cache_position,
    past_key_values=None,
    position_ids=position_ids,
)
attention_mask

>>> tensor([[[[ True, False, False, False, False, False, False, False, False, False],
          [ True,  True, False, False, False, False, False, False, False, False],
          [ True,  True,  True, False, False, False, False, False, False, False],
          [ True,  True,  True,  True, False, False, False, False, False, False],
          [False, False, False, False,  True, False, False, False, False, False],
          [False, False, False, False,  True,  True, False, False, False, False],
          [False, False, False, False, False, False,  True, False, False, False],
          [False, False, False, False, False, False,  True,  True, False, False],
          [False, False, False, False, False, False,  True,  True,  True, False],
          [False, False, False, False, False, False,  True,  True,  True,  True]]]])
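For reference, here is a minimal standalone sketch (my own, not the transformers implementation) of the logic behind this mask: a position that resets to 0 marks the start of a new packed sequence, and a token may only attend to earlier tokens of its own sequence.

import torch

def packed_causal_mask(position_ids: torch.Tensor) -> torch.Tensor:
    """Block-diagonal causal mask of shape (batch, 1, seq, seq) from packed position_ids of shape (batch, seq)."""
    # A position of 0 marks the start of a new packed sequence, so the cumulative sum
    # of those resets assigns a sequence id to every token:
    # [0,1,2,3,0,1,0,1,2,3] -> [1,1,1,1,2,2,3,3,3,3]
    seq_ids = (position_ids == 0).long().cumsum(dim=-1)
    same_sequence = seq_ids[:, :, None] == seq_ids[:, None, :]  # (batch, q_len, kv_len)
    seq_len = position_ids.shape[-1]
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=position_ids.device))
    return (same_sequence & causal)[:, None, :, :]  # add the head dimension

position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 0, 1, 2, 3]])
print(packed_causal_mask(position_ids))  # matches the mask printed above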


@ArthurZucker (Collaborator) left a comment:

Very nice, just missing a test 😉

Comment on lines 624 to 625:

# Packed format is always on batch of size 1 so we can early exit if not the case
if not position_ids.shape[0] == 1:

Contributor: There really shouldn't be a restriction on this. It should also work with 2D packed tensors.

@Cyrilvallez (Member, Author): Hmm, alright, I can lift it easily. I thought it was always packed with all sequences along a batch of 1.
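For illustration (not from the thread): with the restriction lifted, the same create_causal_mask call from the PR description works with a 2D position_ids of batch size greater than 1, reusing the model and imports from that snippet.

# Two rows, each packing its own sub-sequences; real inputs with padding would also pass an attention_mask
position_ids = torch.tensor([
    [0, 1, 2, 0, 1, 2, 3, 0, 1, 2],
    [0, 1, 0, 1, 2, 3, 4, 0, 1, 2],
])
attention_mask = create_causal_mask(
    config=model.config,
    input_embeds=torch.empty((2, 10), dtype=model.dtype),
    attention_mask=None,
    cache_position=torch.arange(10),
    past_key_values=None,
    position_ids=position_ids,
)
print(attention_mask.shape)  # expected: torch.Size([2, 1, 10, 10]), block-diagonal causal per row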

github-actions bot (Contributor) commented Jul 3, 2025:

[For maintainers] Suggested jobs to run (before merge)

run-slow: arcee, aria, bitnet, cohere, cohere2, csm, deepseek_v3, dia, diffllama, dots1, emu3, gemma, gemma2, gemma3, gemma3n, glm

@Cyrilvallez added the "for patch" label (issues / labels that should be included in the next patch) on Jul 3, 2025
@winglian (Contributor) left a comment:

perfect! fixes the regression and flex + packing works again for us now

@Cyrilvallez merged commit 0cf2791 into main on Jul 4, 2025
27 checks passed
@Cyrilvallez deleted the packing-mask branch on July 4, 2025 at 07:01
Cyrilvallez added a commit that referenced this pull request Jul 4, 2025
Add packed tensor format support for flex/sdpa/eager through the mask! (#39194)

* Add the necesary logic to mask_utils

* add it everywhere

* Update masking_utils.py

* style

* Update masking_utils.py

* Update modeling_mimi.py

* Update masking_utils.py

* add support for more than batch size 1

* Update masking_utils.py

* add test

* style

* Update test_masking_utils.py

* Update masking_utils.py

* add require_token

* fix tests

* fix
@BenjaminBossan (Member):

Hey @Cyrilvallez, the docstring of create_masks_for_generate says the argument is optional:

    position_ids (`torch.Tensor`, optional)
       A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.

but it's actually a required argument, which breaks existing code that calls this function. Is it intentional that it's required, or should it have a default of None?

BenjaminBossan added a commit to BenjaminBossan/peft that referenced this pull request Jul 4, 2025
We use create_mask_for_generate from transformers. It was introduced in
v4.53.0 but in v4.53.1, the function signature was changed to include
position_ids as mandatory argument:

huggingface/transformers#39194

This breaks our function call in PEFT. This PR fixes the function call
by passing position_ids. This in turn would break the function call with
transformers v4.53.0, thus a strict version check is being used for >=
v4.53.1.

Moreover, the check has been moved inside the if-branch that actually
needs it instead of performing it at the start of the function. That
way, no error is raised if we don't visit this branch.
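A rough sketch of the calling pattern that commit message describes; the wrapper function and variable names below are placeholders of mine, not the actual PEFT code, and the keyword arguments are assumed to mirror the create_causal_mask call shown earlier rather than a verified signature.

from packaging import version

import transformers


def build_generation_masks(model, inputs_embeds, attention_mask, cache_position, past_key_values, position_ids):
    # Strict version check: position_ids became a required argument of
    # create_masks_for_generate in transformers v4.53.1 (this PR).
    if version.parse(transformers.__version__) < version.parse("4.53.1"):
        raise RuntimeError("transformers >= 4.53.1 is required to pass position_ids here")
    from transformers.masking_utils import create_masks_for_generate

    return create_masks_for_generate(
        config=model.config,
        input_embeds=inputs_embeds,
        attention_mask=attention_mask,
        cache_position=cache_position,
        past_key_values=past_key_values,
        position_ids=position_ids,  # None is accepted for the non-packed case
    )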
@Snowdar commented Jul 7, 2025:

Hi, I would like to ask whether you could implement the attention_mask with the pattern [1,1,2,2,2,3,3,3], and support packed tensors with FlashAttention for scenarios requiring a sparse mask.

This approach would let us leverage a universal method (a 2D attention mask / position IDs) to handle variable-length attention via masking. Additionally, we could extend support to 4D masks for more complex cases, building on SDPA (scaled dot-product attention) and eager attention.

@Cyrilvallez (Member, Author):

Hey @BenjaminBossan! It's optional in the sense that it can be None, but indeed I did not provide a default of None, in order to force the models to pass the argument so that packed format is always allowed (same as for past_key_values).
We could however rethink the default values (e.g. it could make sense to allow cache_position to be None as well when the kv cache is, and construct it on the fly for external usage of the mask functions). Let me know your thoughts!
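As a rough sketch of that last idea (a hypothetical helper, not part of the library): when cache_position is omitted, it could be derived from the past length and the current sequence length.

import torch


def default_cache_position(input_embeds, past_key_values=None):
    # Hypothetical helper: derive cache_position when it is not provided
    past_len = past_key_values.get_seq_length() if past_key_values is not None else 0
    return torch.arange(past_len, past_len + input_embeds.shape[1], device=input_embeds.device)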

@Cyrilvallez (Member, Author):

@Snowdar for your usage of FA2, you should not pass any mask but forward the seqlens directly 🤗
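For context, a segment-id pattern like [1,1,2,2,2,3,3,3] can be reduced to the per-sequence lengths (or cumulative seqlens) that the FlashAttention varlen kernels consume, without materializing any mask. A minimal sketch, with names of my own choosing:

import torch
import torch.nn.functional as F

# Segment ids marking which packed sequence each token belongs to
segment_ids = torch.tensor([1, 1, 2, 2, 2, 3, 3, 3])

# Length of each packed sequence: tensor([2, 3, 3])
_, seqlens = torch.unique_consecutive(segment_ids, return_counts=True)

# Cumulative sequence lengths in the format expected by the varlen kernels: tensor([0, 2, 5, 8], dtype=torch.int32)
cu_seqlens = F.pad(seqlens.cumsum(0), (1, 0)).to(torch.int32)
print(seqlens, cu_seqlens)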

@BenjaminBossan (Member):

    It's optional in the sense that it can be None, but indeed I did not provide a default of None, in order to force the models to pass the argument so that packed format is always allowed (same as for past_key_values).
    We could however rethink the default values (e.g. it could make sense to allow cache_position to be None as well when the kv cache is, and construct it on the fly for external usage of the mask functions). Let me know your thoughts!

Thanks for explaining. I think in this case, it would have been better to provide a default, given that the signature was changed in a backwards incompatible way and then the change was released as a patch release, where the expectation as a user is that I can always upgrade without fear of breakage. I'm not sure if this function is considered "private", but even so, I think providing a default when there is a reasonable one would have been better. Now that the patch release is out, it's too late so I don't have any strong opinion either way.

BenjaminBossan added a commit to huggingface/peft that referenced this pull request on Jul 7, 2025
@ArthurZucker
Copy link
Collaborator

Yep I think we need to patch again to have a default @Cyrilvallez

efraimdahl pushed a commit to efraimdahl/peft that referenced this pull request on Jul 12, 2025
rjgleaton pushed a commit to rjgleaton/transformers that referenced this pull request on Jul 17, 2025
Labels: for patch
6 participants