Conversation

@Cyrilvallez Cyrilvallez commented Feb 24, 2025

What does this PR do?

As per the title. There is a weird issue when running the following:

import torch
from transformers import AutoModelForCausalLM, LlamaConfig

config = LlamaConfig()
config.num_hidden_layers = 2
model = AutoModelForCausalLM.from_config(config).to(0)


model.compile(fullgraph=True)

with torch.no_grad():
    input_ids = torch.randint(0, 100, (1, 200), device=0)
    out = model(input_ids)

# Change shape of input
with torch.no_grad():
    input_ids = torch.randint(0, 100, (1, 100), device=0)
    out = model(input_ids)

It fails with

TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not SymBool
...
RuntimeError: Failed running call_function <built-in function scaled_dot_product_attention>(*(FakeTensor(..., device='cuda:0', size=(1, 32, s0, 128)), FakeTensor(..., device='cuda:0', size=(1, 32, s0, 128)), FakeTensor(..., device='cuda:0', size=(1, 32, s0, 128))), **{'attn_mask': None, 'dropout_p': 0.0, 'scale': 0.08838834764831845, 'is_causal': s0 > 1}):
scaled_dot_product_attention(): argument 'is_causal' must be bool, not SymBool

i.e. it somehow traces the is_causal assignment as a SymBool instead of evaluating it to a concrete bool.
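For illustration, here is a minimal standalone sketch of the pattern involved (not the transformers code; the function name and tensor shapes are made up, and whether it reproduces the exact failure depends on the PyTorch version):

import torch
import torch.nn.functional as F

# Sketch only: under dynamic shapes, `q.shape[2] > 1` is a SymBool, and
# short-circuiting it with the mask check can leave `is_causal` symbolic
# instead of a concrete Python bool when it reaches SDPA.
@torch.compile(fullgraph=True, dynamic=True)
def sdpa(q, k, v, mask=None):
    is_causal = mask is None and q.shape[2] > 1
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=is_causal)

q = k = v = torch.randn(1, 32, 200, 128, device=0)
out = sdpa(q, k, v)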

If we pass a mask every time, it works; however, as soon as we remove it (which should be logically equivalent), it fails again:

import torch
from transformers import AutoModelForCausalLM, LlamaConfig

config = LlamaConfig()
config.num_hidden_layers = 2
model = AutoModelForCausalLM.from_config(config).to(0)


model.compile(fullgraph=True)

with torch.no_grad():
    input_ids = torch.randint(0, 100, (1, 200), device=0)
    out = model(input_ids, attention_mask=torch.ones_like(input_ids))

# Change shape of input but keep a mask
with torch.no_grad():
    input_ids = torch.randint(0, 100, (1, 100), device=0)
    out = model(input_ids, attention_mask=torch.ones_like(input_ids))

# Change shape of input but remove mask -> it fails here
with torch.no_grad():
    input_ids = torch.randint(0, 100, (1, 100), device=0)
    out = model(input_ids)

Simply switching the order of the check in sdpa_attention fixes the issue, and ensures that it works regardless of whether a mask is passed. Not entirely sure why, but it looks like dynamo somehow does not handle the shape check correctly if it is done afterwards.
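For context, the relevant change amounts to reordering the short-circuit condition in sdpa_attention roughly as follows (an approximation of the diff, not the exact code; the mask variable name may differ):

# Before (approximate): the symbolic shape comparison comes last, and when no
# mask is passed dynamo traces the result as a SymBool.
is_causal = attention_mask is None and query.shape[2] > 1

# After (approximate): the shape check comes first; as noted above, it is not
# entirely clear why, but dynamo then resolves `is_causal` to a plain bool.
is_causal = query.shape[2] > 1 and attention_mask is None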

With the current fix, both of the following scenarios work correctly:

Removing the mask after changing the shape:

import torch
from transformers import AutoModelForCausalLM, LlamaConfig

config = LlamaConfig()
config.num_hidden_layers = 2
model = AutoModelForCausalLM.from_config(config).to(0)


model.compile(fullgraph=True)

with torch.no_grad():
    input_ids = torch.randint(0, 100, (1, 200), device=0)
    out = model(input_ids, attention_mask=torch.ones_like(input_ids))

with torch.no_grad():
    input_ids = torch.randint(0, 100, (1, 100), device=0)
    out = model(input_ids, attention_mask=torch.ones_like(input_ids))

# Stop passing the mask
with torch.no_grad():
    input_ids = torch.randint(0, 100, (1, 100), device=0)
    out = model(input_ids)

Adding a mask after changing the shape:

import torch
from transformers import AutoModelForCausalLM, LlamaConfig

config = LlamaConfig()
config.num_hidden_layers = 2
model = AutoModelForCausalLM.from_config(config).to(0)


model.compile(fullgraph=True)

with torch.no_grad():
    input_ids = torch.randint(0, 100, (1, 200), device=0)
    out = model(input_ids)

# Change shape of input
with torch.no_grad():
    input_ids = torch.randint(0, 100, (1, 100), device=0)
    out = model(input_ids)

# Add the mask to the inputs
with torch.no_grad():
    input_ids = torch.randint(0, 100, (1, 100), device=0)
    out = model(input_ids, attention_mask=torch.ones_like(input_ids))

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker ArthurZucker (Collaborator) left a comment

Thanks 🤗

@Cyrilvallez Cyrilvallez merged commit 401543a into main Feb 25, 2025
24 checks passed
@Cyrilvallez Cyrilvallez deleted the fix-compile branch February 25, 2025 09:44