
Commit 9fe6dea

qubvel, koustuvsinha, yonigozlan, and pcuenca authored and committed
Add V-JEPA 2 (huggingface#38746)
* adding model and conversion scripts
* add imports to test vjepa conversion
* fix imports and make conversion work
* fix computation for short side
* replace attention with library attention function
* cleanup more attention classes
* remove config overrides
* add test cases, fix some of the failing ones
* fix the model outputs
* fix outputs of the model per review
* fix too big model test case
* fix styling __init__.py
* fix initialization test
* remove all asserts per review
* update sorting unsorting logic as per feedback
* remove is_video per review
* remove another is_video segment
* remove unwanted stuff
* small fixes
* add docstrings for the model
* revert adding vjepa2 config here
* update styling
* add config docstrings (wip)
* fix dpr issue
* removed test failing issues
* update styles
* merge predictor configs into main config
* remove processing code, add video processor
* remove permute which is not necessary now
* fix styles
* updated vjepa2 to be in video_processing_auto
* update comment for preprocessing
* test integration test and fix the outputs
* update test values, change test to look at repeated frames for a given image
* add a simple video processing test
* refactoring pixel_values_videos and upload ckpts to original
* fix torch_fx test cases
* remove unused config
* add all config docstrings
* add more integration tests
* add basic doc
* revert unwanted styling changes
* working make fixup
* Fix model_type in config
* update attention implementation to fit new hf standards
* fix the preprocessing logic, ensure it matches the original model
* remove use_rope logic, cleanup
* fix docstrings
* Further cleanup, update doc
* Fix model prefix
* fix get_vision_features
* VJEPA2Embeddings style refactor
* nit, style comment
* change modules default values
* Only `str` activation in config
* GradientCheckpointingLayer
* fixup
* fix conversion script
* Remove return_dict
* remove None return typehint
* Refactor VJEPA2Layer, remove use_SiLU
* Fix fx tests
* dpr -> drop_path_rates
* move *ModelOutput on top
* format docs bit
* update docs
* update docs
* update doc example
* remove prune_heads from model
* remove unused config params
* refactor embed signature
* Add vjepa to docs
* Fix config docstring
* update defaults
* Update docs/source/en/model_doc/vjepa2.md (Co-authored-by: Pedro Cuenca <[email protected]>)
* Update docs/source/en/model_doc/vjepa2.md (Co-authored-by: Pedro Cuenca <[email protected]>)
* Fix import
* Min refactoring
* Update HUB_SOURCE and HUB_REPO in conversion script
* Add missing headers
* VJEPA -> V-JEPA in docs
* Add image to doc
* fix style
* fix init weights
* change checkpoint name in modeling tests

---------

Co-authored-by: Koustuv Sinha <[email protected]>
Co-authored-by: yonigozlan <[email protected]>
Co-authored-by: Yoni Gozlan <[email protected]>
Co-authored-by: Koustuv Sinha <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>
1 parent be58a8b commit 9fe6dea

File tree

15 files changed: +1919 -0 lines changed


docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions

```diff
@@ -905,6 +905,8 @@
       - sections:
           - local: model_doc/timesformer
             title: TimeSformer
+          - local: model_doc/vjepa2
+            title: V-JEPA 2
           - local: model_doc/videomae
             title: VideoMAE
           - local: model_doc/vivit
```

docs/source/en/model_doc/vjepa2.md

Lines changed: 82 additions & 0 deletions

<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

<div style="float: right;">
    <div class="flex flex-wrap space-x-1">
        <img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
        <img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
        <img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
    </div>
</div>

# V-JEPA 2

V-JEPA 2 is a self-supervised approach to training video encoders developed by FAIR, Meta. Using internet-scale video data, V-JEPA 2 attains state-of-the-art performance on motion understanding and human action anticipation tasks. V-JEPA 2-AC is a latent action-conditioned world model post-trained from V-JEPA 2 (using a small amount of robot trajectory interaction data) that solves robot manipulation tasks without environment-specific data collection or task-specific training or calibration.

<div class="flex justify-center">
    <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/vjepa.gif" alt="drawing" width="600"/>
</div>

You can find all original V-JEPA 2 checkpoints under the [V-JEPA 2](https://huggingface.co/collections/facebook/v-jepa-2-6841bad8413014e185b497a6) collection.

This model was contributed by [koustuvs](https://huggingface.co/koustuvs), [yonigozlan](https://huggingface.co/yonigozlan) and [qubvel](https://huggingface.co/qubvel-hf). The original code can be found [here](https://github.com/facebookresearch/vjepa2).
## Usage example

The snippet below shows how to load the V-JEPA 2 model and its video processor using the `AutoModel` and `AutoVideoProcessor` classes.

```py
import numpy as np
import torch
from torchcodec.decoders import VideoDecoder

from transformers import AutoModel, AutoVideoProcessor

processor = AutoVideoProcessor.from_pretrained("facebook/vjepa2-vitl-fpc64-256")
model = AutoModel.from_pretrained(
    "facebook/vjepa2-vitl-fpc64-256",
    torch_dtype=torch.float16,
    device_map="auto",
    attn_implementation="sdpa",
)

video_url = "https://huggingface.co/datasets/nateraw/kinetics-mini/resolve/main/val/archery/-Qz25rXdMjE_000014_000024.mp4"

vr = VideoDecoder(video_url)
frame_idx = np.arange(0, 64)  # sample the first 64 frames; a more complex sampling strategy can be defined here
video = vr.get_frames_at(indices=frame_idx).data  # T x C x H x W
video = processor(video, return_tensors="pt").to(model.device)
outputs = model(**video)

# V-JEPA 2 encoder outputs, same as calling `model.get_vision_features()`
encoder_outputs = outputs.last_hidden_state

# V-JEPA 2 predictor outputs
predictor_outputs = outputs.predictor_output.last_hidden_state
```
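
If you only need a single clip-level embedding rather than per-patch features, you can pool the encoder tokens. The snippet below is a minimal sketch, not part of the original usage example; it reuses `encoder_outputs` from above and assumes mean pooling is an adequate aggregation for your downstream task.

```py
# Hypothetical follow-up to the example above: collapse the patch tokens into a
# single embedding per clip by averaging over the token dimension.
clip_embedding = encoder_outputs.mean(dim=1)  # (batch_size, hidden_size)
print(clip_embedding.shape)
```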

## VJEPA2Config

[[autodoc]] VJEPA2Config

## VJEPA2Model

[[autodoc]] VJEPA2Model
    - forward

## VJEPA2VideoProcessor

[[autodoc]] VJEPA2VideoProcessor

src/transformers/models/__init__.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -323,6 +323,7 @@
     from .vitpose_backbone import *
     from .vits import *
     from .vivit import *
+    from .vjepa2 import *
     from .wav2vec2 import *
     from .wav2vec2_bert import *
     from .wav2vec2_conformer import *
```

src/transformers/models/auto/configuration_auto.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -365,6 +365,7 @@
         ("vitpose_backbone", "VitPoseBackboneConfig"),
         ("vits", "VitsConfig"),
         ("vivit", "VivitConfig"),
+        ("vjepa2", "VJEPA2Config"),
         ("wav2vec2", "Wav2Vec2Config"),
         ("wav2vec2-bert", "Wav2Vec2BertConfig"),
         ("wav2vec2-conformer", "Wav2Vec2ConformerConfig"),
@@ -750,6 +751,7 @@
         ("vitpose_backbone", "ViTPoseBackbone"),
         ("vits", "VITS"),
         ("vivit", "ViViT"),
+        ("vjepa2", "VJEPA2Model"),
         ("wav2vec2", "Wav2Vec2"),
         ("wav2vec2-bert", "Wav2Vec2-BERT"),
         ("wav2vec2-conformer", "Wav2Vec2-Conformer"),
```

src/transformers/models/auto/modeling_auto.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -336,6 +336,7 @@
         ("vitdet", "VitDetModel"),
         ("vits", "VitsModel"),
         ("vivit", "VivitModel"),
+        ("vjepa2", "VJEPA2Model"),
         ("wav2vec2", "Wav2Vec2Model"),
         ("wav2vec2-bert", "Wav2Vec2BertModel"),
         ("wav2vec2-conformer", "Wav2Vec2ConformerModel"),
```

src/transformers/models/auto/video_processing_auto.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -56,6 +56,7 @@
         ("qwen2_vl", "Qwen2VLVideoProcessor"),
         ("smolvlm", "SmolVLMVideoProcessor"),
         ("video_llava", "VideoLlavaVideoProcessor"),
+        ("vjepa2", "VJEPA2VideoProcessor"),
     ]
 )
```
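
Together, the `vjepa2` entries added to these auto mappings let the generic Auto classes resolve the new model type. A minimal sketch of what this enables (the checkpoint name is the one used in the model doc above):

```python
from transformers import AutoConfig, AutoModel, AutoVideoProcessor

# Each lookup below resolves through one of the mappings registered in this commit.
config = AutoConfig.from_pretrained("facebook/vjepa2-vitl-fpc64-256")             # VJEPA2Config
processor = AutoVideoProcessor.from_pretrained("facebook/vjepa2-vitl-fpc64-256")  # VJEPA2VideoProcessor
model = AutoModel.from_pretrained("facebook/vjepa2-vitl-fpc64-256")               # VJEPA2Model

print(type(config).__name__, type(processor).__name__, type(model).__name__)
```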

src/transformers/models/vjepa2/__init__.py

Lines changed: 29 additions & 0 deletions

```python
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_vjepa2 import *
    from .modeling_vjepa2 import *
    from .video_processing_vjepa2 import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
```
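
The lazy module above means the public V-JEPA 2 symbols are importable from the package but only materialized on first access. A small illustrative sketch (not part of the commit):

```python
# Accessing a symbol through the lazy module triggers the actual import.
from transformers.models.vjepa2 import VJEPA2Config

config = VJEPA2Config()
print(config.model_type)  # "vjepa2"
```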
src/transformers/models/vjepa2/configuration_vjepa2.py

Lines changed: 146 additions & 0 deletions

````python
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""VJEPA 2 model configuration"""

from ...configuration_utils import PretrainedConfig


class VJEPA2Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`VJEPA2Model`]. It is used to instantiate a
    VJEPA2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the VJEPA2
    [facebook/vjepa2-vitl-fpc64-256](https://huggingface.co/facebook/vjepa2-vitl-fpc64-256) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        patch_size (`int`, *optional*, defaults to 16):
            The size (resolution) of each patch.
        crop_size (`int`, *optional*, defaults to 256):
            Input resolution of the model.
        frames_per_clip (`int`, *optional*, defaults to 64):
            The number of frames the model has been pretrained with. Does not impact inference.
        tubelet_size (`int`, *optional*, defaults to 2):
            The number of temporal frames used for a single tubelet; check the paper for more information.
        hidden_size (`int`, *optional*, defaults to 1024):
            Dimensionality of the encoder layers.
        in_chans (`int`, *optional*, defaults to 3):
            The number of input channels.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Encoder.
        num_hidden_layers (`int`, *optional*, defaults to 24):
            The number of hidden layers.
        drop_path_rate (`float`, *optional*, defaults to 0.0):
            Stochastic depth rate per sample (when applied in the main path of residual layers).
        mlp_ratio (`float`, *optional*, defaults to 4.0):
            Ratio of the hidden size of the MLPs used in the Encoder relative to the `hidden_size`.
        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the layer normalization layers.
        qkv_bias (`bool`, *optional*, defaults to `True`):
            Whether to add a bias to the queries, keys and values.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
            The dropout probability for the attention probabilities.
        hidden_act (`str`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"` and `"gelu_new"` are supported.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        pred_hidden_size (`int`, *optional*, defaults to 384):
            Dimensionality of the predictor layers.
        pred_num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Predictor.
        pred_num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Predictor.
        pred_num_mask_tokens (`int`, *optional*, defaults to 10):
            Defines the number of mask tokens to use in the Predictor.
        pred_zero_init_mask_tokens (`bool`, *optional*, defaults to `True`):
            Whether to initialize the mask tokens in the Predictor with zeros.
        pred_mlp_ratio (`float`, *optional*, defaults to 4.0):
            Ratio of the hidden size of the MLPs used in the Predictor relative to the `pred_hidden_size`.

    Example:

    ```python
    >>> from transformers import VJEPA2Config, VJEPA2Model

    >>> # Initializing a VJEPA2 vjepa2-vitl-fpc64-256 style configuration
    >>> configuration = VJEPA2Config()

    >>> # Initializing a model (with random weights) from the vjepa2-vitl-fpc64-256 style configuration
    >>> model = VJEPA2Model(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "vjepa2"

    def __init__(
        self,
        patch_size=16,
        crop_size=256,
        frames_per_clip=64,
        tubelet_size=2,
        hidden_size=1024,
        in_chans=3,
        num_attention_heads=16,
        num_hidden_layers=24,
        drop_path_rate=0.0,
        mlp_ratio=4.0,
        layer_norm_eps=1e-6,
        qkv_bias=True,
        attention_probs_dropout_prob=0.0,
        hidden_act="gelu",
        initializer_range=0.02,
        # predictor params
        pred_hidden_size=384,
        pred_num_attention_heads=12,
        pred_num_hidden_layers=12,
        pred_num_mask_tokens=10,
        pred_zero_init_mask_tokens=True,
        pred_mlp_ratio=4.0,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.crop_size = crop_size
        self.frames_per_clip = frames_per_clip
        self.patch_size = patch_size
        self.tubelet_size = tubelet_size
        self.hidden_size = hidden_size
        self.in_chans = in_chans
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        self.drop_path_rate = drop_path_rate
        self.mlp_ratio = mlp_ratio
        self.layer_norm_eps = layer_norm_eps
        self.qkv_bias = qkv_bias
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.image_size = crop_size
        # predictor params
        self.pred_hidden_size = pred_hidden_size
        self.pred_num_attention_heads = pred_num_attention_heads
        self.pred_num_hidden_layers = pred_num_hidden_layers
        self.pred_num_mask_tokens = pred_num_mask_tokens
        self.pred_zero_init_mask_tokens = pred_zero_init_mask_tokens
        self.pred_mlp_ratio = pred_mlp_ratio


__all__ = ["VJEPA2Config"]
````
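
For quick local experimentation it can be handy to instantiate a much smaller, randomly initialized variant of this configuration. The values below are illustrative assumptions rather than a released checkpoint configuration; they only respect the usual constraints that hidden sizes divide evenly by the number of attention heads and that the crop size divides by the patch size.

```python
from transformers import VJEPA2Config, VJEPA2Model

# Hypothetical tiny configuration for smoke tests; all sizes are illustrative only.
tiny_config = VJEPA2Config(
    crop_size=64,             # divisible by patch_size (16)
    frames_per_clip=4,
    hidden_size=64,           # divisible by num_attention_heads
    num_attention_heads=2,
    num_hidden_layers=2,
    pred_hidden_size=32,      # divisible by pred_num_attention_heads
    pred_num_attention_heads=2,
    pred_num_hidden_layers=2,
)
model = VJEPA2Model(tiny_config)
print(f"{sum(p.numel() for p in model.parameters()):,} parameters")
```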
