Merged
Commits
6c1c472
adding model and conversion scripts
koustuvsinha Jun 3, 2025
103bde8
add imports to test vjepa conversion
koustuvsinha Jun 3, 2025
d09c1c0
fix imports and make conversion work
koustuvsinha Jun 3, 2025
3fd6bf2
fix computation for short side
koustuvsinha Jun 3, 2025
aae79f7
replace attention with library attention function
koustuvsinha Jun 4, 2025
41957a6
cleanup more attention classes
koustuvsinha Jun 4, 2025
f46df32
remove config overrides
koustuvsinha Jun 4, 2025
174ba39
add test cases, fix some of the failing ones
koustuvsinha Jun 4, 2025
3e41279
fix the model outputs
koustuvsinha Jun 4, 2025
8711628
fix outputs of the model per review
koustuvsinha Jun 4, 2025
da7f76a
fix too big model test case
koustuvsinha Jun 4, 2025
beba328
Merge remote-tracking branch 'upstream/main' into koustuvs/oss
yonigozlan Jun 4, 2025
d54b0a4
Merge branch 'koustuvs/oss' of https://github.com/huggingface/new-mod…
yonigozlan Jun 4, 2025
db6b5e3
fix styling __init__.py
yonigozlan Jun 4, 2025
df77afe
fix initialization test
koustuvsinha Jun 5, 2025
239b30f
remove all asserts per review
koustuvsinha Jun 6, 2025
dd4850d
update sorting unsorting logic as per feedback
koustuvsinha Jun 6, 2025
d8b18b1
remove is_video per review
koustuvsinha Jun 6, 2025
eb955bf
remove another is_video segment
koustuvsinha Jun 6, 2025
a382ebb
remove unwanted stuff
koustuvsinha Jun 6, 2025
de97de0
small fixes
koustuvsinha Jun 6, 2025
30f6feb
add docstrings for the model
koustuvsinha Jun 6, 2025
f5a07b2
revert adding vjepa2 config here
koustuvsinha Jun 6, 2025
d90da4d
update styling
koustuvsinha Jun 6, 2025
238f8f3
add config docstrings (wip)
koustuvsinha Jun 6, 2025
9bc533d
fix dpr issue
koustuvsinha Jun 7, 2025
0354e0c
removed test failing issues
koustuvsinha Jun 7, 2025
fdb6697
update styles
koustuvsinha Jun 7, 2025
e203d41
merge predictor configs into main config
koustuvsinha Jun 7, 2025
3c60022
remove processing code, add video processor
koustuvsinha Jun 9, 2025
3879772
remove permute which is not necessary now
koustuvsinha Jun 9, 2025
ac0e97b
fix styles
koustuvsinha Jun 9, 2025
6ba6f46
updated vjepa2 to be in video_processing_auto
koustuvsinha Jun 9, 2025
0a98f5f
update comment for preprocessing
koustuvsinha Jun 9, 2025
4394345
test integration test and fix the outputs
koustuvsinha Jun 9, 2025
436fce5
update test values, change test to look at repeated frames for a give…
koustuvsinha Jun 9, 2025
9752cec
add a simple video processing test
koustuvsinha Jun 9, 2025
9f004fa
refactoring pixel_values_videos and upload ckpts to original
koustuvsinha Jun 9, 2025
8409e83
fix torch_fx test cases
koustuvsinha Jun 9, 2025
6b8c776
remove unused config
koustuvsinha Jun 9, 2025
5fc0f9d
add all config docstrings
koustuvsinha Jun 9, 2025
a966ca2
add more integration tests
koustuvsinha Jun 9, 2025
824f734
add basic doc
koustuvsinha Jun 10, 2025
0f8be62
revert unwanted styling changes
yonigozlan Jun 10, 2025
6fa271b
working make fixup
yonigozlan Jun 10, 2025
9b5d9dd
Fix model_type in config
qubvel Jun 10, 2025
7ce0d8c
update attention implementation to fit new hf standards
yonigozlan Jun 10, 2025
6522b7f
fix the preprocessing logic, ensure it matches the original model
koustuvsinha Jun 10, 2025
5042ec8
Merge pull request #3 from huggingface/standardize-state-dict
koustuvsinha Jun 10, 2025
eed77e7
remove use_rope logic, cleanup
yonigozlan Jun 10, 2025
a04edb2
fix docstrings
yonigozlan Jun 10, 2025
f1e8605
Further cleanup, update doc
yonigozlan Jun 10, 2025
525eaae
Fix model prefix
yonigozlan Jun 10, 2025
e5faf69
fix get_vision_features
koustuvsinha Jun 10, 2025
79aebe5
VJEPA2Embeddings style refactor
qubvel Jun 10, 2025
cffcbf3
nit, style comment
qubvel Jun 10, 2025
1e3fa77
change modules default values
qubvel Jun 10, 2025
5479538
Only `str` activation in config
qubvel Jun 10, 2025
4a320a9
GradientCheckpointingLayer
qubvel Jun 10, 2025
2ac05f3
fixup
qubvel Jun 10, 2025
15c1f95
fix conversion script
koustuvsinha Jun 10, 2025
24bb730
Remove return_dict
qubvel Jun 10, 2025
71df552
remove None return typehint
qubvel Jun 10, 2025
a0242be
Refactor VJEPA2Layer, remove use_SiLU
qubvel Jun 10, 2025
bb574e4
Fix fx tests
qubvel Jun 10, 2025
1b7e8f8
dpr -> drop_path_rates
qubvel Jun 10, 2025
21e956d
move *ModelOutput on top
qubvel Jun 10, 2025
0a7a13f
format docs bit
qubvel Jun 10, 2025
c1365e3
update docs
koustuvsinha Jun 10, 2025
4cc33e2
update docs
koustuvsinha Jun 10, 2025
b77105d
update doc example
koustuvsinha Jun 10, 2025
38d25bf
remove prune_heads from model
yonigozlan Jun 11, 2025
918d302
remove unused config params
yonigozlan Jun 11, 2025
4275757
refactor embed signature
qubvel Jun 11, 2025
ddd1c36
Add vjepa to docs
qubvel Jun 11, 2025
b8f67a4
Fix config docstring
qubvel Jun 11, 2025
6d6aadc
update defaults
qubvel Jun 11, 2025
9661409
Update docs/source/en/model_doc/vjepa2.md
qubvel Jun 11, 2025
f67a4d7
Update docs/source/en/model_doc/vjepa2.md
qubvel Jun 11, 2025
2908c5c
Fix import
qubvel Jun 11, 2025
91028fc
Merge branch 'main' into koustuvs/oss
qubvel Jun 11, 2025
fb95d52
Min refactoring
qubvel Jun 11, 2025
39aafc3
Update HUB_SOURCE and HUB_REPO in conversion script
qubvel Jun 11, 2025
7e660ef
Add missing headers
qubvel Jun 11, 2025
9061b3a
VJEPA -> V-JEPA in docs
qubvel Jun 11, 2025
284b3e5
Add image to doc
yonigozlan Jun 11, 2025
a7d750b
fix style
qubvel Jun 11, 2025
f80ff37
Merge branch 'koustuvs/oss' of https://github.com/qubvel/transformers…
yonigozlan Jun 11, 2025
1058246
fix init weights
yonigozlan Jun 11, 2025
a85656e
change checkpoint name in modeling tests
yonigozlan Jun 11, 2025
78 changes: 78 additions & 0 deletions docs/source/en/model_doc/vjepa2.md
@@ -0,0 +1,78 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->


<div style="float: right;">
<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
</div>
</div>

# V-JEPA 2

V-JEPA 2 is a self-supervised approach to training video encoders developed by Meta FAIR. Using internet-scale video data, V-JEPA 2 attains state-of-the-art performance on motion understanding and human action anticipation tasks. V-JEPA 2-AC is a latent action-conditioned world model post-trained from V-JEPA 2 (using a small amount of robot trajectory interaction data) that solves robot manipulation tasks without environment-specific data collection, task-specific training, or calibration.

You can find all original V-JEPA 2 checkpoints under the [V-JEPA 2](https://huggingface.co/collections/facebook/v-jepa-2-6841bad8413014e185b497a6) collection.


This model was contributed by [koustuvs](https://huggingface.co/koustuvs) and [yonigozlan](https://huggingface.co/yonigozlan). The original code can be found [here](https://github.com/facebookresearch/vjepa2).

## Usage example

The snippet below shows how to load the V-JEPA 2 model using the `AutoModel` class.

```py
import torch
import numpy as np
from torchcodec.decoders import VideoDecoder

from transformers import AutoModel, AutoVideoProcessor

processor = AutoVideoProcessor.from_pretrained("facebook/vjepa2-vitl-fpc64-256")
# Review comment (Member): Just flagging to confirm this be the final repo name @merveenoyan @ariG23498 ?
# Reply (Contributor): yes this is the final repo name!

model = AutoModel.from_pretrained(
"facebook/vjepa2-vitl-fpc64-256",
torch_dtype=torch.float16,
device_map="auto",
attn_implementation="sdpa"
)

video_url = "https://huggingface.co/datasets/nateraw/kinetics-mini/resolve/main/val/archery/-Qz25rXdMjE_000014_000024.mp4"

vr = VideoDecoder(video_url)
frame_idx = np.arange(0, 64) # sample the first 64 frames; you can define a more complex sampling strategy here
video = vr.get_frames_at(indices=frame_idx).data # T x C x H x W
video = processor(video, return_tensors="pt").to(model.device)
outputs = model(**video)

# V-JEPA 2 encoder outputs, same as calling `model.get_vision_features()`
encoder_outputs = outputs.last_hidden_state

# V-JEPA 2 predictor outputs
predictor_outputs = outputs.predictor_output.last_hidden_state
```
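
If you only need a single clip-level feature (for probing or retrieval), you can pool the encoder tokens. Below is a minimal sketch building on the snippet above; it assumes, as the comment in the snippet notes, that `model.get_vision_features(**video)` returns the same patch-token features as `outputs.last_hidden_state`:

```py
# Equivalent to `outputs.last_hidden_state` from the snippet above
encoder_outputs = model.get_vision_features(**video)

# Mean-pool the patch tokens into one embedding per clip
clip_embedding = encoder_outputs.mean(dim=1)  # shape: (batch_size, hidden_size)
print(clip_embedding.shape)
```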

## VJEPA2Config

[[autodoc]] VJEPA2Config

## VJEPA2Model

[[autodoc]] VJEPA2Model
- forward

## VJEPA2VideoProcessor

[[autodoc]] VJEPA2VideoProcessor
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
@@ -317,6 +317,7 @@
from .vitpose_backbone import *
from .vits import *
from .vivit import *
from .vjepa2 import *
from .wav2vec2 import *
from .wav2vec2_bert import *
from .wav2vec2_conformer import *
2 changes: 2 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
@@ -364,6 +364,7 @@
("vitpose_backbone", "VitPoseBackboneConfig"),
("vits", "VitsConfig"),
("vivit", "VivitConfig"),
("vjepa2", "VJEPA2Config"),
("wav2vec2", "Wav2Vec2Config"),
("wav2vec2-bert", "Wav2Vec2BertConfig"),
("wav2vec2-conformer", "Wav2Vec2ConformerConfig"),
@@ -748,6 +749,7 @@
("vitpose_backbone", "ViTPoseBackbone"),
("vits", "VITS"),
("vivit", "ViViT"),
("vjepa2", "VJEPA2Model"),
("wav2vec2", "Wav2Vec2"),
("wav2vec2-bert", "Wav2Vec2-BERT"),
("wav2vec2-conformer", "Wav2Vec2-Conformer"),
1 change: 1 addition & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -335,6 +335,7 @@
("vitdet", "VitDetModel"),
("vits", "VitsModel"),
("vivit", "VivitModel"),
("vjepa2", "VJEPA2Model"),
("wav2vec2", "Wav2Vec2Model"),
("wav2vec2-bert", "Wav2Vec2BertModel"),
("wav2vec2-conformer", "Wav2Vec2ConformerModel"),
1 change: 1 addition & 0 deletions src/transformers/models/auto/video_processing_auto.py
@@ -56,6 +56,7 @@
("qwen2_vl", "Qwen2VLVideoProcessor"),
("smolvlm", "SmolVLMVideoProcessor"),
("video_llava", "VideoLlavaVideoProcessor"),
("vjepa2", "VJEPA2VideoProcessor"),
]
)

14 changes: 14 additions & 0 deletions src/transformers/models/vjepa2/__init__.py
@@ -0,0 +1,14 @@
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
from .configuration_vjepa2 import *
from .modeling_vjepa2 import *
else:
import sys

_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
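
The `_LazyModule` indirection above defers the heavy imports until a symbol is first accessed. A minimal check of the lazy export, assuming the standard top-level re-export (added to `models/__init__.py` in this PR):

```py
# Importing the submodule is cheap; VJEPA2Config is only materialized on first attribute access
from transformers.models.vjepa2 import VJEPA2Config

config = VJEPA2Config()
print(config.model_type)  # "vjepa2"
```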
132 changes: 132 additions & 0 deletions src/transformers/models/vjepa2/configuration_vjepa2.py
@@ -0,0 +1,132 @@
"""VJEPA 2 model configuration"""

from ...configuration_utils import PretrainedConfig


class VJEPA2Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`VJEPA2Model`]. It is used to instantiate a
VJEPA2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the VJEPA2
[facebook/vjepa2-vitl-fpc64-256](https://huggingface.co/facebook/vjepa2-vitl-fpc64-256) architecture.

Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.

Args:
patch_size (`int`, *optional*, defaults to 16):
The size (resolution) of each patch.
crop_size (`int`, *optional*, defaults to 256):
Input resolution of the model.
frames_per_clip (`int`, *optional*, defaults to 64):
The number of frames the model has been pretrained with. Does not impact inference.
tubelet_size (`int`, *optional*, defaults to 2):
The number of temporal frames combined into a single tubelet (spatiotemporal patch); check the paper for more information.
hidden_size (`int`, *optional*, defaults to 1024):
Dimensionality of the encoder layers.
in_chans (`int`, *optional*, defaults to 3):
The number of input channels.
num_attention_heads (`int`, *optional*, defaults to 12):
Number of attention heads for each attention layer in the Encoder.
num_hidden_layers (`int`, *optional*, defaults to 12):
The number of hidden layers in the Encoder.
drop_path_rate (`float`, *optional*, defaults to 0.0):
Stochastic depth rate per sample (when applied in the main path of residual layers).
mlp_ratio (`float`, *optional*, defaults to 4.0):
Ratio of the hidden size of the MLPs used in Encoder relative to the `hidden_size`.
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the layer normalization layers.
qkv_bias (`bool`, *optional*, defaults to `True`):
Whether to add a bias to the queries, keys and values.
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
The dropout probability for the attention probabilities.
hidden_act (`str`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"selu"` and `"gelu_new"` are supported.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
pred_hidden_size (`int`, *optional*, defaults to 384):
Dimensionality of the Predictor layers.
pred_num_attention_heads (`int`, *optional*, defaults to 12):
Number of attention heads for each attention layer in the Predictor.
pred_num_hidden_layers (`int`, *optional*, defaults to 12):
Number of hidden layers in the Predictor.
pred_num_mask_tokens (`int`, *optional*, defaults to 10):
The number of mask tokens to use in the Predictor.
pred_zero_init_mask_tokens (`bool`, *optional*, defaults to `True`):
Whether to initialize the Predictor mask tokens with zeros.
pred_mlp_ratio (`float`, *optional*, defaults to 4.0):
Ratio of the hidden size of the MLPs used in the Predictor relative to `pred_hidden_size`.

Example:

```python
>>> from transformers import VJEPA2Config, VJEPA2Model

>>> # Initializing a VJEPA2 vjepa2-vitl-fpc64-256 style configuration
>>> configuration = VJEPA2Config()

>>> # Initializing a model (with random weights) from the vjepa2-vitl-fpc64-256 style configuration
>>> model = VJEPA2Model(configuration)

>>> # Accessing the model configuration
>>> configuration = model.config
```"""

model_type = "vjepa2"

def __init__(
self,
patch_size=16,
crop_size=256,
frames_per_clip=64,
tubelet_size=2,
hidden_size=1024,
in_chans=3,
num_attention_heads=12,
num_hidden_layers=12,
drop_path_rate=0.0,
mlp_ratio=4.0,
layer_norm_eps=1e-6,
qkv_bias=True,
attention_probs_dropout_prob=0.0,
hidden_act="gelu",
initializer_range=0.02,
# predictor params
pred_hidden_size=384,
pred_num_attention_heads=12,
pred_num_hidden_layers=12,
pred_num_mask_tokens=10,
pred_zero_init_mask_tokens=True,
pred_mlp_ratio=4.0,
**kwargs,
):
super().__init__(**kwargs)

self.crop_size = crop_size
self.frames_per_clip = frames_per_clip
self.patch_size = patch_size
self.tubelet_size = tubelet_size
self.hidden_size = hidden_size
self.in_chans = in_chans
self.num_attention_heads = num_attention_heads
self.num_hidden_layers = num_hidden_layers
self.drop_path_rate = drop_path_rate
self.mlp_ratio = mlp_ratio
self.layer_norm_eps = layer_norm_eps
self.qkv_bias = qkv_bias
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.image_size = crop_size
# predictor params
self.pred_hidden_size = pred_hidden_size
self.pred_num_attention_heads = pred_num_attention_heads
self.pred_num_hidden_layers = pred_num_hidden_layers
self.pred_num_mask_tokens = pred_num_mask_tokens
self.pred_zero_init_mask_tokens = pred_zero_init_mask_tokens
self.pred_mlp_ratio = pred_mlp_ratio


__all__ = ["VJEPA2Config"]
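
The encoder arguments (`hidden_size`, `num_hidden_layers`, ...) and the Predictor arguments (`pred_*`) compose independently. A minimal sketch of a deliberately downsized configuration for quick smoke tests follows; the values are illustrative only and do not correspond to any released checkpoint:

```py
from transformers import VJEPA2Config, VJEPA2Model

# Illustrative small values; hidden sizes stay divisible by the number of attention heads
config = VJEPA2Config(
    crop_size=128,
    frames_per_clip=16,
    hidden_size=256,
    num_hidden_layers=4,
    num_attention_heads=4,
    pred_hidden_size=128,
    pred_num_hidden_layers=2,
    pred_num_attention_heads=4,
)
model = VJEPA2Model(config)  # randomly initialized weights
```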