Merged
Commits
57 commits
35e1a64
add sdpa to OPT
Aug 29, 2024
908e39b
chore: remove redundant whitespace in OPTDecoder class
Aug 29, 2024
c84a4dd
fixup
Sep 2, 2024
be32f92
bug fix
Sep 4, 2024
8063994
add sdpa and attention generate test
Sep 4, 2024
248029a
fixup
Sep 4, 2024
b66e3d8
Refactor OPTAttention forward method for improved readability and mai…
Sep 8, 2024
579d60e
undo refactor for _shape and key,val states
Sep 8, 2024
b105376
add OPT to doc, fixup didn't find it for some reason
Sep 10, 2024
c349632
change order
Sep 10, 2024
6dba8b0
change default attn_implementation in testing to eager
Sep 11, 2024
989625b
Merge branch 'main' into spda_opt
avishaiElmakies Sep 11, 2024
1d21751
[run-slow] opt
Sep 16, 2024
7233fda
change test_eager_matches_sdpa_generate to the llama one
Sep 17, 2024
9bacdeb
Update default attention implementation in testing common
Sep 17, 2024
5b38f78
[run-slow] opt
Sep 17, 2024
3f24a04
remove unneeded print
Sep 17, 2024
2efd25a
[run-slow] opt
Sep 17, 2024
bdd9cb2
refactor model testers to have attn_implementation="eager"
Sep 18, 2024
f80e3b3
[run-slow] opt
Sep 18, 2024
7ea22eb
convert test_eager_matches_sdpa_generate to opt-350M
Sep 22, 2024
b5547e7
bug fix when creating mask for opt
Sep 22, 2024
eaa8028
Merge branch 'main' into spda_opt
Sep 22, 2024
668e291
[run-slow] opt
Sep 22, 2024
d9d3bb3
if layer head mask default to eager
Sep 22, 2024
388d663
if head mask is not none fall to eager
Sep 22, 2024
e735ec4
[run-slow] opt
Sep 22, 2024
f94d574
Update src/transformers/models/opt/modeling_opt.py
avishaiElmakies Sep 25, 2024
e734d9d
Clean up Unpack imports (#33631)
molbap Sep 23, 2024
34593ba
Fix DPT /Dinov2 sdpa regression on main (#33660)
molbap Sep 23, 2024
6889d69
handle dependency errors in check_imports (#33622)
molbap Sep 23, 2024
d488c33
add back self.max_position_embeddings = config.max_position_embedding…
chengchengpei Sep 23, 2024
9990915
Fix Llava conversion for LlavaQwen2ForCausalLM with Clip vision tower…
Isotr0py Sep 23, 2024
3720eca
Uniformize kwargs for Udop processor and update docs (#33628)
yonigozlan Sep 23, 2024
9b11d28
Generation: deprecate `PreTrainedModel` inheriting from `GenerationMi…
gante Sep 23, 2024
d3f8417
Enable BNB multi-backend support (#31098)
jiqing-feng Sep 24, 2024
52a0a75
Fix error string after refactoring into get_chat_template (#33652)
tibor-reiss Sep 24, 2024
400927e
uniformize git processor (#33668)
yonigozlan Sep 24, 2024
3b0d24c
Modular `transformers`: modularity and inheritance for new model addi…
ArthurZucker Sep 24, 2024
ef64c81
Fix CIs post merging modular transformers (#33681)
ArthurZucker Sep 24, 2024
6cd88aa
Fixed docstring for cohere model regarding unavailability of prune_he…
mnauf Sep 24, 2024
4a457c1
Generation tests: update imagegpt input name, remove unused functions…
gante Sep 24, 2024
4deac16
Improve Error Messaging for Flash Attention 2 on CPU (#33655)
sizhky Sep 24, 2024
1f7d50a
Gemma2: fix config initialization (`cache_implementation`) (#33684)
gante Sep 24, 2024
3e798fa
Fix ByteLevel alphabet missing when Sequence pretokenizer is used (#3…
umarbutler Sep 24, 2024
9665ecc
Uniformize kwargs for image-text-to-text processors (#32544)
yonigozlan Sep 25, 2024
e1839b9
🚨🚨 Setting default behavior of assisted decoding (#33657)
jmamou Sep 25, 2024
37da2d6
tests: fix pytorch tensor placement errors (#33485)
dvrogozh Sep 25, 2024
58c2b2b
bump tokenizers, fix added tokens fast (#32535)
ArthurZucker Sep 25, 2024
f0bb0a8
[Pixtral] Improve docs, rename model (#33491)
NielsRogge Sep 25, 2024
34a9142
fix code quality after merge
ArthurZucker Sep 25, 2024
6aeec65
HFQuantizer implementation for compressed-tensors library (#31704)
bfineran Sep 25, 2024
3e69375
Merge branch 'main' into spda_opt
avishaiElmakies Sep 25, 2024
a9b18dc
update model card for opt
Sep 25, 2024
9876dbb
add batch size to inference table
Sep 25, 2024
ff35bbc
[slow-run] opt
Sep 25, 2024
cfd1209
[run-slow] opt
Sep 27, 2024
1 change: 1 addition & 0 deletions docs/source/en/perf_infer_gpu_one.md
@@ -240,6 +240,7 @@ For now, Transformers supports SDPA inference and training for the following architectures:
* [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel)
* [OLMo](https://huggingface.co/docs/transformers/model_doc/olmo#transformers.OlmoModel)
* [OLMoE](https://huggingface.co/docs/transformers/model_doc/olmoe#transformers.OlmoeModel)
* [OPT](https://huggingface.co/docs/transformers/en/model_doc/opt)
* [PaliGemma](https://huggingface.co/docs/transformers/model_doc/paligemma#transformers.PaliGemmaForConditionalGeneration)
* [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel)
* [Phi3](https://huggingface.co/docs/transformers/model_doc/phi3#transformers.Phi3Model)
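With OPT added to the list above, SDPA can be selected at load time through `attn_implementation`. A minimal sketch, assuming the `facebook/opt-350m` checkpoint and a CUDA device (any OPT checkpoint should behave the same):

```python
# Minimal sketch: select the SDPA implementation added by this PR when loading OPT.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",
    torch_dtype=torch.float16,
    attn_implementation="sdpa",  # "eager" and "flash_attention_2" remain available
).to("cuda")

inputs = tokenizer("Today I am in Paris and", return_tensors="pt").to("cuda")
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```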
175 changes: 150 additions & 25 deletions src/transformers/models/opt/modeling_opt.py
@@ -22,7 +22,10 @@
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from ...modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
@@ -116,7 +119,7 @@ def __init__(
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias)

def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor:
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

def forward(
@@ -359,9 +362,106 @@ def forward(
return attn_output, attn_weights_reshaped, past_key_value


class OPTSdpaAttention(OPTAttention):
"""
OPT SDPA attention module. This module inherits from `OPTAttention`, as the module's weights stay untouched.
The only change is in the forward pass, which needs to call the public SDPA API correctly and handle padding
tokens if the input contains any.
"""

def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
position_ids: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions or layer_head_mask is not None:
logger.warning_once("""OPTModel is using SDPA attention, which currently does not support output_attentions=True.
failing back to eager attention. remove warning using attn_implementation="eager".""")

return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
key_value_states=key_value_states,
) # TODO after merge add position_ids=position_ids
is_cross_attention = key_value_states is not None

bsz, q_len, _ = hidden_states.size()

query_states = self.q_proj(hidden_states) * self.scaling
query_states = self._shape(query_states, -1, bsz)

# get key, value proj
if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)

# shape is now (bsz, num_heads, seq_len, head_dim); all tensors are contiguous

causal_mask = attention_mask
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if causal_mask is None and q_len > 1 else False

attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
dropout_p=self.dropout if self.training else 0.0,
is_causal=is_causal,
# OPT applies the scaling factor inside the query projection rather than to Q @ K^T,
# so we pass scale=1.0 to disable SDPA's internal scaling and match the eager results.
# The model could instead be changed to drop the scaling from the query projection.
scale=1.0,
)

attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, -1)
attn_output = self.out_proj(attn_output)

return attn_output, None, past_key_value
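A quick numerical sanity check of the `scale=1.0` choice above, as a standalone sketch (not part of the diff, and assuming a torch version where `scaled_dot_product_attention` accepts `scale`, i.e. 2.1+): pre-scaling the query and disabling SDPA's internal scaling matches SDPA's default `1/sqrt(head_dim)` scaling.

```python
# Verify that query pre-scaling + scale=1.0 matches SDPA's default scaling.
import torch
import torch.nn.functional as F

q = torch.randn(1, 2, 4, 8)  # (bsz, num_heads, seq_len, head_dim)
k, v = torch.randn_like(q), torch.randn_like(q)
scaling = q.shape[-1] ** -0.5  # same factor OPT folds into the query projection

out_prescaled = F.scaled_dot_product_attention(q * scaling, k, v, scale=1.0)
out_default = F.scaled_dot_product_attention(q, k, v)  # applies 1/sqrt(head_dim) internally
torch.testing.assert_close(out_prescaled, out_default)
```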


OPT_ATTENTION_CLASSES = {
"eager": OPTAttention,
"flash_attention_2": OptFlashAttention2,
"sdpa": OPTSdpaAttention,
}


@@ -488,6 +588,7 @@ class OPTPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["OPTDecoderLayer"]
_supports_flash_attn_2 = True
_supports_sdpa = True

def _init_weights(self, module):
std = self.config.init_std
@@ -604,6 +705,7 @@ def __init__(self, config: OPTConfig):

self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self._use_sdpa = config._attn_implementation == "sdpa"

self.gradient_checkpointing = False
# Initialize weights and apply final processing
@@ -615,6 +717,49 @@ def get_input_embeddings(self):
def set_input_embeddings(self, value):
self.embed_tokens = value

def _update_causal_mask(
self,
inputs_embeds: torch.Tensor,
input_shape: Tuple[int, int],
past_key_values_length: int,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
):
"""
Updates the causal mask for the decoder.
"""
batch_size, seq_length = input_shape
mask_seq_length = past_key_values_length + seq_length
if self._use_flash_attention_2:
# 2d mask is passed through the layers
causal_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
attention_mask = (
torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
if attention_mask is None
else attention_mask
)

return causal_attention_mask, attention_mask

if attention_mask is None:
attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
elif attention_mask.shape[1] != mask_seq_length:
raise ValueError(
f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
f"{mask_seq_length} (sum of the lengths of current and past inputs)"
)
if self._use_sdpa and not output_attentions and head_mask is None:
causal_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)
else:
causal_attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)

return causal_attention_mask, attention_mask

def forward(
self,
input_ids: torch.LongTensor = None,
@@ -696,32 +841,12 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

batch_size, seq_length = input_shape
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
# required mask seq length can be calculated via length of past
mask_seq_length = past_key_values_length + seq_length

causal_attention_mask, attention_mask = self._update_causal_mask(
inputs_embeds, input_shape, past_key_values_length, attention_mask, head_mask, output_attentions
)
# embed positions
if self._use_flash_attention_2:
# 2d mask is passed through the layers
causal_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
attention_mask = (
torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
if attention_mask is None
else attention_mask
)
else:
# 4d mask is passed through the layers
if attention_mask is None:
attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
elif attention_mask.shape[1] != mask_seq_length:
raise ValueError(
f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
f"{mask_seq_length} (sum of the lengths of current and past inputs)"
)
causal_attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)

pos_embeds = self.embed_positions(attention_mask, past_key_values_length)

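For context on the `_update_causal_mask` helper added above, the following sketch shows the contract of the two mask-preparation utilities it dispatches to. This assumes the private `modeling_attn_mask_utils` API as of this PR; the signatures may change in later releases.

```python
# Sketch: the eager helper always returns an additive 4D mask, while the SDPA helper
# may return None for a fully causal, unpadded input so SDPA can rely on is_causal=True.
import torch
from transformers.modeling_attn_mask_utils import (
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)

bsz, seq_len, hidden = 2, 5, 8
inputs_embeds = torch.zeros(bsz, seq_len, hidden)
attention_mask = torch.ones(bsz, seq_len)  # no padding
past_key_values_length = 0

eager_mask = _prepare_4d_causal_attention_mask(
    attention_mask, (bsz, seq_len), inputs_embeds, past_key_values_length
)
sdpa_mask = _prepare_4d_causal_attention_mask_for_sdpa(
    attention_mask, (bsz, seq_len), inputs_embeds, past_key_values_length
)
print(eager_mask.shape)  # torch.Size([2, 1, 5, 5])
print(sdpa_mask)         # None here: no padding, so SDPA can use is_causal=True instead
```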
3 changes: 3 additions & 0 deletions tests/models/opt/test_modeling_flax_opt.py
@@ -70,6 +70,7 @@ def __init__(
embed_dim=16,
word_embed_proj_dim=16,
initializer_range=0.02,
attn_implemetation="eager",
):
self.parent = parent
self.batch_size = batch_size
@@ -92,6 +93,7 @@ def __init__(
self.word_embed_proj_dim = word_embed_proj_dim
self.initializer_range = initializer_range
self.is_encoder_decoder = False
self.attn_implementation = attn_implementation

def prepare_config_and_inputs(self):
input_ids = np.clip(ids_tensor([self.batch_size, self.seq_length - 1], self.vocab_size), 3, self.vocab_size)
@@ -114,6 +116,7 @@ def prepare_config_and_inputs(self):
word_embed_proj_dim=self.word_embed_proj_dim,
initializer_range=self.initializer_range,
use_cache=False,
attn_implementation=self.attn_implementation,
)
inputs_dict = prepare_opt_inputs_dict(config, input_ids)
return config, inputs_dict
74 changes: 73 additions & 1 deletion tests/models/opt/test_modeling_opt.py
@@ -21,7 +21,14 @@
import timeout_decorator # noqa

from transformers import OPTConfig, is_torch_available
from transformers.testing_utils import require_torch, require_torch_accelerator, require_torch_fp16, slow, torch_device
from transformers.testing_utils import (
require_torch,
require_torch_accelerator,
require_torch_fp16,
require_torch_sdpa,
slow,
torch_device,
)

from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
@@ -83,6 +90,7 @@ def __init__(
num_labels=3,
word_embed_proj_dim=16,
type_sequence_label_size=2,
attn_implementation="eager",
):
self.parent = parent
self.batch_size = batch_size
@@ -106,6 +114,7 @@ def __init__(
self.type_sequence_label_size = type_sequence_label_size
self.word_embed_proj_dim = word_embed_proj_dim
self.is_encoder_decoder = False
self.attn_implementation = attn_implementation

def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(
@@ -135,6 +144,7 @@ def get_config(self):
embed_dim=self.embed_dim,
is_encoder_decoder=False,
word_embed_proj_dim=self.word_embed_proj_dim,
attn_implementation=self.attn_implementation,
)

def get_pipeline_config(self):
@@ -322,6 +332,68 @@ def test_opt_sequence_classification_model_for_multi_label(self):
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

@require_torch_sdpa
@slow
def test_eager_matches_sdpa_generate(self):
"""
Overwriting the common test as the test is flaky on tiny models
"""
max_new_tokens = 30

tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350M")

texts = [
"hi here's a longer context, getting longer and",
"Hello this is a very long sentence my friend, very long for real",
"Today I am in Paris and",
]

model_sdpa = OPTForCausalLM.from_pretrained(
"facebook/opt-350M",
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
attn_implementation="sdpa",
).to(torch_device)

self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")

model_eager = OPTForCausalLM.from_pretrained(
"facebook/opt-350M",
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
attn_implementation="eager",
).to(torch_device)

self.assertTrue(model_eager.config._attn_implementation == "eager")

for _, submodule in model_eager.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
raise ValueError("The eager model should not have SDPA attention layers")

has_sdpa = False
for _, submodule in model_sdpa.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
has_sdpa = True
break
if not has_sdpa:
raise ValueError("The SDPA model should have SDPA attention layers")

for padding_side in ["left", "right"]:
tokenizer.padding_side = padding_side
tokenizer.pad_token = tokenizer.eos_token

inputs = tokenizer(texts, return_tensors="pt", padding=True).to(torch_device)

res_eager = model_eager.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
res_sdpa = model_sdpa.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)

with self.subTest(f"{padding_side}"):
torch.testing.assert_close(
res_eager,
res_sdpa,
msg=f"\n{tokenizer.batch_decode(res_eager)} \nvs\n{tokenizer.batch_decode(res_sdpa)}",
)

@unittest.skip(reason="Does not work on the tiny model as we keep hitting edge cases.")
def test_model_parallelism(self):
super().test_model_parallelism()
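Complementing the eager-vs-SDPA generation test above, a short sketch of the fallback path added in `OPTSdpaAttention` (the `facebook/opt-125m` checkpoint is used purely for illustration): requesting attention weights from an SDPA-loaded model falls back to eager attention for that call and emits a one-time warning.

```python
# Sketch: output_attentions=True (or a head mask) makes the SDPA attention class
# fall back to the eager implementation, so attention weights are still returned.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", attn_implementation="sdpa")

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
with torch.no_grad():
    out = model(**inputs, output_attentions=True)  # triggers the eager fallback + warning

print(len(out.attentions), out.attentions[0].shape)  # per-layer (bsz, num_heads, seq, seq) weights
```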
3 changes: 3 additions & 0 deletions tests/models/opt/test_modeling_tf_opt.py
@@ -66,6 +66,7 @@ def __init__(
bos_token_id=0,
embed_dim=16,
word_embed_proj_dim=16,
attn_implementation="eager",
):
self.parent = parent
self.batch_size = batch_size
@@ -87,6 +88,7 @@ def __init__(
self.embed_dim = embed_dim
self.word_embed_proj_dim = word_embed_proj_dim
self.is_encoder_decoder = False
self.attn_implementation = attn_implementation

def prepare_config_and_inputs_for_common(self):
input_ids = ids_tensor([self.batch_size, self.seq_length - 1], self.vocab_size)
@@ -108,6 +110,7 @@ def prepare_config_and_inputs_for_common(self):
embed_dim=self.embed_dim,
word_embed_proj_dim=self.word_embed_proj_dim,
is_encoder_decoder=False,
attn_implementation=self.attn_implementation,
**self.config_updates,
)
inputs_dict = prepare_opt_inputs_dict(config, input_ids)