Merged
Changes from 4 commits
99 changes: 82 additions & 17 deletions src/transformers/models/pixtral/modeling_pixtral.py
@@ -14,6 +14,7 @@
# limitations under the License.
"""PyTorch Pixtral model."""

from collections.abc import Callable
from typing import Optional, Tuple, Union

import torch
@@ -22,13 +23,11 @@

from ... import PreTrainedModel
from ...activations import ACT2FN
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutput
from ...modeling_rope_utils import dynamic_rope_update
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
)
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_pixtral import PixtralVisionConfig


@@ -132,8 +131,37 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
return q_embed, k_embed


def _eager_bidir_attention(
module: nn.Module,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
attn_mask: Optional[torch.Tensor],
scaling: float,
dropout: float,
**kwargs,
):
"""
Eager (plain-PyTorch) bidirectional attention.
Accepts the same signature as the functions in ALL_ATTENTION_FUNCTIONS and, like them,
returns attn_output in (batch, seq_len, num_heads, head_dim).
"""
# (B, H, L, D) @ (B, H, D, L) ➜ (B, H, L, L)
scores = torch.matmul(q, k.transpose(-2, -1)) * scaling
if attn_mask is not None:
scores = scores + attn_mask

attn = torch.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)
attn = nn.functional.dropout(attn, p=dropout, training=module.training)

# (B, H, L, L) @ (B, H, L, D) ➜ (B, H, L, D)
out = torch.matmul(attn, v).transpose(1, 2).contiguous() # (B, L, H, D)
return out, attn


class PixtralAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
"""
Multi-headed attention compatible with ALL_ATTENTION_FUNCTIONS.
"""

def __init__(self, config):
super().__init__()
@@ -142,6 +170,8 @@ def __init__(self, config):
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads

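# Vision patches attend to each other bidirectionally, so is_causal is always False
# (it is passed straight through to the attention backends below).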
self.is_causal = False

self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout

@@ -155,7 +185,9 @@ def forward(
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
position_ids: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
**flash_kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Input shape: Batch x Time x Channel"""

@@ -172,17 +204,36 @@
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=0)

attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale

if attention_mask is not None:
attn_weights = attn_weights + attention_mask

# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
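# `_attn_implementation` is set on the config when the model is instantiated
# (e.g. "eager", "sdpa", "flash_attention_2").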
impl = getattr(self.config, "_attn_implementation", "eager")

if impl == "sdpa" and output_attentions:
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
impl = "eager"

# Patches from all images are packed into a single sequence; when Flash Attention 2 is
# selected we drop the block-diagonal mask and rely on position_ids to mark image boundaries.
if impl == "flash_attention_2":
position_ids = position_ids.to(hidden_states.device, non_blocking=True)
attention_mask = None
flash_kwargs["position_ids"] = position_ids

attn_fn: Callable = _eager_bidir_attention if impl == "eager" else ALL_ATTENTION_FUNCTIONS[impl]

attn_output, attn_weights = attn_fn(
self,
query_states,
key_states,
value_states,
attention_mask,
scaling=self.scale,
dropout=0.0 if not self.training else self.dropout,
is_causal=self.is_causal,
output_attentions=output_attentions,
**flash_kwargs,
)

# the attention functions above already return (batch, seq_len, num_heads, head_dim)
attn_output = attn_output.reshape(batch_size, patches, -1).contiguous()

attn_output = self.o_proj(attn_output)
@@ -241,6 +292,7 @@ def forward(
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
position_ids: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
) -> Tuple[torch.FloatTensor]:
"""
@@ -260,6 +312,7 @@ def forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_embeddings=position_embeddings,
position_ids=position_ids,
output_attentions=output_attentions,
)
hidden_states = residual + hidden_states
@@ -290,6 +343,7 @@ def forward(
inputs_embeds,
attention_mask: Optional[torch.Tensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
position_ids: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
@@ -333,13 +387,15 @@ def forward(
hidden_states,
attention_mask,
position_embeddings,
position_ids,
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
position_embeddings=position_embeddings,
position_ids=position_ids,
output_attentions=output_attentions,
)

@@ -381,6 +437,9 @@ class PixtralPreTrainedModel(PreTrainedModel):
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = ["PixtralAttentionLayer"]
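# Capability flags checked by `PreTrainedModel` before enabling the corresponding attention backend.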
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True

def _init_weights(self, module):
std = self.config.initializer_range
@@ -457,7 +516,7 @@ def get_input_embeddings(self):
def forward(
self,
pixel_values: torch.Tensor,
image_sizes: torch.Tensor,
image_sizes: Optional[torch.Tensor] = None,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
@@ -469,6 +528,11 @@
pixel_values: tensor of token features for
all tokens of all images of shape (N_toks, D)
"""
if image_sizes is None:
batch_size, _, height, width = pixel_values.shape
# build a list of (H, W) tuples, one per image in the batch (all images assumed the same size)
image_sizes = [(height, width)] * batch_size

# pass images through initial convolution independently
patch_embeds = self.patch_conv(pixel_values)
patch_embeds_list = [
@@ -494,6 +558,7 @@ def forward(
patch_embeds,
attention_mask=attention_mask,
position_embeddings=position_embeddings,
position_ids=position_ids,
output_hidden_states=output_hidden_states,
output_attentions=output_attentions,
return_dict=return_dict,