
Commit 78b2929

Avishai Elmakies (avishaiElmakies) authored
Sdpa dino v2 (#33403)
* add sdpa to dinov2
* fixup
* add dinov2 to sdpa doc
* update doc order
* [run-slow] dinov2
* common to eager
* [run-slow] dinov2
* update attn implementation in common
* update test_modeling_dinov2 to have mask_ratio, num_masks and mask_length similar to vit
* [run-slow] dinov2

---------

Co-authored-by: Avishai Elmakies <[email protected]>
1 parent e71bf70 commit 78b2929

File tree

3 files changed (+55, -2 lines)

docs/source/en/perf_infer_gpu_one.md

Lines changed: 1 addition & 1 deletion
@@ -217,6 +217,7 @@ For now, Transformers supports SDPA inference and training for the following architectures:
 * [data2vec_audio](https://huggingface.co/docs/transformers/main/en/model_doc/data2vec#transformers.Data2VecAudioModel)
 * [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel)
 * [DeiT](https://huggingface.co/docs/transformers/model_doc/deit#transformers.DeiTModel)
+* [Dinov2](https://huggingface.co/docs/transformers/en/model_doc/dinov2)
 * [Dpr](https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DprReader)
 * [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
 * [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
@@ -275,7 +276,6 @@ For now, Transformers supports SDPA inference and training for the following architectures:
 * [XLM-RoBERTa-XL](https://huggingface.co/docs/transformers/model_doc/xlm-roberta-xl#transformers.XLMRobertaXLModel)
 * [YOLOS](https://huggingface.co/docs/transformers/model_doc/yolos#transformers.YolosModel)
 
-
 <Tip>
 
 FlashAttention can only be used for models with the `fp16` or `bf16` torch type, so make sure to cast your model to the appropriate type first. The memory-efficient attention backend is able to handle `fp32` models.
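With DINOv2 added to this list, the SDPA path can be requested the same way as for the other supported models. A minimal usage sketch; the `facebook/dinov2-base` checkpoint and the blank test image are only examples:

import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModel

processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
model = AutoModel.from_pretrained(
    "facebook/dinov2-base",
    attn_implementation="sdpa",  # dispatches to the Dinov2SdpaAttention added by this commit
)

image = Image.new("RGB", (224, 224))  # stand-in for a real image
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # (1, num_patches + 1, hidden_size)

Passing attn_implementation="eager" instead restores the original attention module, which is handy for A/B comparisons.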

src/transformers/models/dinov2/modeling_dinov2.py

Lines changed: 47 additions & 1 deletion
@@ -231,6 +231,38 @@ def forward(
         return outputs
 
 
+# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->Dinov2
+class Dinov2SdpaSelfAttention(Dinov2SelfAttention):
+    def __init__(self, config: Dinov2Config) -> None:
+        super().__init__(config)
+        self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
+
+    def forward(
+        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
+    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+        mixed_query_layer = self.query(hidden_states)
+
+        key_layer = self.transpose_for_scores(self.key(hidden_states))
+        value_layer = self.transpose_for_scores(self.value(hidden_states))
+        query_layer = self.transpose_for_scores(mixed_query_layer)
+
+        context_layer = torch.nn.functional.scaled_dot_product_attention(
+            query_layer,
+            key_layer,
+            value_layer,
+            head_mask,
+            self.attention_probs_dropout_prob if self.training else 0.0,
+            is_causal=False,
+            scale=None,
+        )
+
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.view(new_context_layer_shape)
+
+        return context_layer, None
+
+
 # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Dinov2
 class Dinov2SelfOutput(nn.Module):
     """
@@ -290,6 +322,13 @@ def forward(
         return outputs
 
 
+# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->Dinov2
+class Dinov2SdpaAttention(Dinov2Attention):
+    def __init__(self, config: Dinov2Config) -> None:
+        super().__init__(config)
+        self.attention = Dinov2SdpaSelfAttention(config)
+
+
 class Dinov2LayerScale(nn.Module):
     def __init__(self, config) -> None:
         super().__init__()
@@ -371,14 +410,20 @@ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
         return self.weights_out(hidden)
 
 
+DINOV2_ATTENTION_CLASSES = {
+    "eager": Dinov2Attention,
+    "sdpa": Dinov2SdpaAttention,
+}
+
+
 class Dinov2Layer(nn.Module):
     """This corresponds to the Block class in the original implementation."""
 
     def __init__(self, config: Dinov2Config) -> None:
         super().__init__()
 
         self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-        self.attention = Dinov2Attention(config)
+        self.attention = DINOV2_ATTENTION_CLASSES[config._attn_implementation](config)
         self.layer_scale1 = Dinov2LayerScale(config)
         self.drop_path = Dinov2DropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
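With the DINOV2_ATTENTION_CLASSES mapping in place, every Dinov2Layer resolves its attention module from config._attn_implementation, which is set from the attn_implementation argument at load or construction time. A quick sketch for inspecting the dispatch on a tiny random-weight model (nothing is downloaded; assumes a PyTorch version recent enough for SDPA, and the toy config values are arbitrary):

from transformers import Dinov2Config, Dinov2Model

config = Dinov2Config(
    hidden_size=64,
    num_hidden_layers=2,
    num_attention_heads=4,
    attn_implementation="sdpa",  # request the new SDPA path
)
model = Dinov2Model(config)

print(config._attn_implementation)                      # "sdpa"
print(type(model.encoder.layer[0].attention).__name__)  # "Dinov2SdpaAttention"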

@@ -485,6 +530,7 @@ class Dinov2PreTrainedModel(PreTrainedModel):
     main_input_name = "pixel_values"
     supports_gradient_checkpointing = True
     _no_split_modules = ["Dinov2SwiGLUFFN"]
+    _supports_sdpa = True
 
     def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
         """Initialize the weights"""

tests/models/dinov2/test_modeling_dinov2.py

Lines changed: 7 additions & 0 deletions
@@ -65,6 +65,8 @@ def __init__(
         type_sequence_label_size=10,
         initializer_range=0.02,
         scope=None,
+        attn_implementation="eager",
+        mask_ratio=0.5,
     ):
         self.parent = parent
         self.batch_size = batch_size
@@ -83,10 +85,14 @@ def __init__(
         self.type_sequence_label_size = type_sequence_label_size
         self.initializer_range = initializer_range
         self.scope = scope
+        self.attn_implementation = attn_implementation
+        self.mask_ratio = mask_ratio
 
         # in Dinov2, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
         num_patches = (image_size // patch_size) ** 2
         self.seq_length = num_patches + 1
+        self.num_masks = int(self.mask_ratio * self.seq_length)
+        self.mask_length = num_patches
 
     def prepare_config_and_inputs(self):
         pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@@ -113,6 +119,7 @@ def get_config(self):
             attention_probs_dropout_prob=self.attention_probs_dropout_prob,
             is_decoder=False,
             initializer_range=self.initializer_range,
+            attn_implementation=self.attn_implementation,
         )
 
     def create_and_check_model(self, config, pixel_values, labels):
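The new mask_ratio, num_masks and mask_length attributes mirror the ViT tester (as the commit message notes) and describe the boolean patch mask used by the shared masked-image-modeling style checks. A small sketch of the mask those numbers imply, using illustrative values rather than the tester's actual defaults:

import torch

image_size, patch_size, mask_ratio = 224, 14, 0.5  # illustrative values only
num_patches = (image_size // patch_size) ** 2       # 256 patch tokens
seq_length = num_patches + 1                        # +1 for the [CLS] token
num_masks = int(mask_ratio * seq_length)            # 128
mask_length = num_patches                           # only patch tokens can be masked

# One boolean mask per image: num_masks randomly chosen patch positions set to True.
masked = torch.randperm(mask_length)[:num_masks]
bool_masked_pos = torch.zeros(mask_length, dtype=torch.bool)
bool_masked_pos[masked] = True
print(bool_masked_pos.sum().item())  # 128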
