Skip to content
Merged
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
97199e3
Make ViT Pooler configurable, so that it is possible to pick the acti…
sebbaur Mar 3, 2025
43639fa
Add documentation and allow functions as activations (instead of just…
sebbaur Mar 3, 2025
80e132a
formatting change
sebbaur Mar 3, 2025
9fe0691
Use ACT2FN
sebbaur Mar 3, 2025
bd2a001
Formatting change
sebbaur Mar 3, 2025
18edafb
Formatting changes
sebbaur Mar 3, 2025
5060f56
force pooler_act to be string
sebbaur Mar 3, 2025
0a2895f
force pooler_act to be string
sebbaur Mar 3, 2025
9e9dfd1
Add configs to OBJECTS_TO_IGNORE to make check_docstrings happy
sebbaur Mar 3, 2025
0d2b485
Making the same change in ijepa to make check_modular_conversion happy
sebbaur Mar 3, 2025
e0d6928
Add IJepaConfig to make CI happy
sebbaur Mar 3, 2025
0881e5f
rename pooler_size to pooler_output_size as defined in the config
sebbaur Mar 3, 2025
e2670cc
typo
sebbaur Mar 3, 2025
0382960
Merge branch 'main' into feature/vit_pooler_configurable
sebbaur Mar 3, 2025
c89f69b
revert change to ignore variable
sebbaur Mar 4, 2025
6309d96
Merge branch 'feature/vit_pooler_configurable' of https://github.com/…
sebbaur Mar 4, 2025
6c6a1ac
Merge branch 'main' into feature/vit_pooler_configurable
sebbaur Mar 4, 2025
5ffac41
Merge branch 'main' into feature/vit_pooler_configurable
sebbaur Mar 4, 2025
ad2aeca
Ran utils/check_docstrings.py --fix_and_overwrite
sebbaur Mar 4, 2025
bc5ef1b
revert unrelated change
sebbaur Mar 4, 2025
2d0071f
remove redundant defaults
sebbaur Mar 4, 2025
3ca67c1
rename self.act -> self.activation
sebbaur Mar 6, 2025
05499eb
Merge branch 'main' into feature/vit_pooler_configurable
sebbaur Mar 6, 2025
ab56b63
Merge branch 'main' into feature/vit_pooler_configurable
sebbaur Mar 7, 2025
308b5b3
Merge branch 'main' into feature/vit_pooler_configurable
sebbaur Mar 10, 2025
7e9bc63
Merge branch 'main' into feature/vit_pooler_configurable
sebbaur Mar 10, 2025
eff0bd1
Merge branch 'main' into feature/vit_pooler_configurable
sebbaur Mar 11, 2025
8028f62
Merge branch 'main' into feature/vit_pooler_configurable
sebbaur Mar 17, 2025
87b9162
tanh activation function in mapping
sebbaur Mar 20, 2025
16dfb96
Merge branch 'feature/vit_pooler_configurable' of https://github.com/…
sebbaur Mar 20, 2025
366edfe
Merge branch 'main' into feature/vit_pooler_configurable
sebbaur Mar 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/transformers/models/deit/configuration_deit.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,12 @@ class DeiTConfig(PretrainedConfig):
Whether to add a bias to the queries, keys and values.
encoder_stride (`int`, *optional*, defaults to 16):
Factor to increase the spatial resolution by in the decoder head for masked image modeling.
pooler_output_size (`int`, *optional*):
Dimensionality of the pooler layer. If None, defaults to `hidden_size`.
pooler_act (`str`, *optional*, defaults to `"tanh"`):
The activation function to be used by the pooler. Keys of ACT2FN are supported for Flax and
PyTorch, and elements of https://www.tensorflow.org/api_docs/python/tf/keras/activations are
supported for TensorFlow.

Example:

Expand Down Expand Up @@ -103,6 +109,8 @@ def __init__(
num_channels=3,
qkv_bias=True,
encoder_stride=16,
pooler_output_size=None,
pooler_act="tanh",
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -121,6 +129,8 @@ def __init__(
self.num_channels = num_channels
self.qkv_bias = qkv_bias
self.encoder_stride = encoder_stride
self.pooler_output_size = pooler_output_size if pooler_output_size else hidden_size
self.pooler_act = pooler_act


class DeiTOnnxConfig(OnnxConfig):
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/deit/modeling_deit.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,8 +656,8 @@ def forward(
class DeiTPooler(nn.Module):
def __init__(self, config: DeiTConfig):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
self.dense = nn.Linear(config.hidden_size, config.pooler_output_size)
self.activation = ACT2FN[config.pooler_act]

def forward(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/deit/modeling_tf_deit.py
Original file line number Diff line number Diff line change
Expand Up @@ -813,9 +813,9 @@ def __init__(self, config: DeiTConfig, **kwargs):
super().__init__(**kwargs)

self.dense = keras.layers.Dense(
units=config.hidden_size,
units=config.pooler_output_size,
kernel_initializer=get_initializer(config.initializer_range),
activation="tanh",
activation=config.pooler_act,
name="dense",
)
self.config = config
Expand Down
10 changes: 10 additions & 0 deletions src/transformers/models/dpt/configuration_dpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,12 @@ class DPTConfig(PretrainedConfig):
backbone_kwargs (`dict`, *optional*):
Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
pooler_output_size (`int`, *optional*):
Dimensionality of the pooler layer. If None, defaults to `hidden_size`.
pooler_act (`str`, *optional*, defaults to `"tanh"`):
The activation function to be used by the pooler. Keys of ACT2FN are supported for Flax and
PyTorch, and elements of https://www.tensorflow.org/api_docs/python/tf/keras/activations are
supported for TensorFlow.

Example:

Expand Down Expand Up @@ -173,6 +179,8 @@ def __init__(
use_pretrained_backbone=False,
use_timm_backbone=False,
backbone_kwargs=None,
pooler_output_size=None,
pooler_act="tanh",
**kwargs,
):
super().__init__(**kwargs)
Expand Down Expand Up @@ -268,6 +276,8 @@ def __init__(
self.auxiliary_loss_weight = auxiliary_loss_weight
self.semantic_loss_ignore_index = semantic_loss_ignore_index
self.semantic_classifier_dropout = semantic_classifier_dropout
self.pooler_output_size = pooler_output_size if pooler_output_size else hidden_size
self.pooler_act = pooler_act

def to_dict(self):
"""
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/dpt/modeling_dpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,8 +955,8 @@ def forward(
class DPTViTPooler(nn.Module):
def __init__(self, config: DPTConfig):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
self.dense = nn.Linear(config.hidden_size, config.pooler_output_size)
self.activation = ACT2FN[config.pooler_act]

def forward(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
Expand Down
10 changes: 10 additions & 0 deletions src/transformers/models/ijepa/configuration_ijepa.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@ class IJepaConfig(PretrainedConfig):
The number of input channels.
qkv_bias (`bool`, *optional*, defaults to `True`):
Whether to add a bias to the queries, keys and values.
pooler_output_size (`int`, *optional*):
Dimensionality of the pooler layer. If None, defaults to `hidden_size`.
pooler_act (`str`, *optional*, defaults to `"tanh"`):
The activation function to be used by the pooler. Keys of ACT2FN are supported for Flax and
PyTorch, and elements of https://www.tensorflow.org/api_docs/python/tf/keras/activations are
supported for TensorFlow.
Example:
Expand Down Expand Up @@ -89,6 +95,8 @@ def __init__(
patch_size=16,
num_channels=3,
qkv_bias=True,
pooler_output_size=None,
pooler_act="tanh",
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -106,6 +114,8 @@ def __init__(
self.patch_size = patch_size
self.num_channels = num_channels
self.qkv_bias = qkv_bias
self.pooler_output_size = pooler_output_size if pooler_output_size else hidden_size
self.pooler_act = pooler_act


__all__ = ["IJepaConfig"]
4 changes: 2 additions & 2 deletions src/transformers/models/ijepa/modeling_ijepa.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,8 +492,8 @@ def forward(
class IJepaPooler(nn.Module):
def __init__(self, config: IJepaConfig):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
self.dense = nn.Linear(config.hidden_size, config.pooler_output_size)
self.activation = ACT2FN[config.pooler_act]

def forward(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
Expand Down
10 changes: 10 additions & 0 deletions src/transformers/models/vit/configuration_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@ class ViTConfig(PretrainedConfig):
Whether to add a bias to the queries, keys and values.
encoder_stride (`int`, *optional*, defaults to 16):
Factor to increase the spatial resolution by in the decoder head for masked image modeling.
pooler_output_size (`int`, *optional*):
Dimensionality of the pooler layer. If None, defaults to `hidden_size`.
pooler_act (`str`, *optional*, defaults to `"tanh"`):
The activation function to be used by the pooler. Keys of ACT2FN are supported for Flax and
PyTorch, and elements of https://www.tensorflow.org/api_docs/python/tf/keras/activations are
supported for TensorFlow.

Example:

Expand Down Expand Up @@ -102,6 +108,8 @@ def __init__(
num_channels=3,
qkv_bias=True,
encoder_stride=16,
pooler_output_size=None,
pooler_act="tanh",
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -120,6 +128,8 @@ def __init__(
self.num_channels = num_channels
self.qkv_bias = qkv_bias
self.encoder_stride = encoder_stride
self.pooler_output_size = pooler_output_size if pooler_output_size else hidden_size
self.pooler_act = pooler_act


class ViTOnnxConfig(OnnxConfig):
Expand Down
5 changes: 3 additions & 2 deletions src/transformers/models/vit/modeling_flax_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,17 +412,18 @@ class FlaxViTPooler(nn.Module):

def setup(self):
self.dense = nn.Dense(
self.config.hidden_size,
self.config.pooler_output_size,
kernel_init=jax.nn.initializers.variance_scaling(
self.config.initializer_range**2, "fan_in", "truncated_normal"
),
dtype=self.dtype,
)
self.activation = ACT2FN[self.config.pooler_act]
Comment on lines 420 to +421
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"tanh" is a KeyError here, it does not exist in the mapping


def __call__(self, hidden_states):
cls_hidden_state = hidden_states[:, 0]
cls_hidden_state = self.dense(cls_hidden_state)
return nn.tanh(cls_hidden_state)
return self.activation(cls_hidden_state)


class FlaxViTPreTrainedModel(FlaxPreTrainedModel):
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/vit/modeling_tf_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,9 +788,9 @@ def __init__(self, config: ViTConfig, **kwargs):
super().__init__(**kwargs)

self.dense = keras.layers.Dense(
units=config.hidden_size,
units=config.pooler_output_size,
kernel_initializer=get_initializer(config.initializer_range),
activation="tanh",
activation=config.pooler_act,
name="dense",
)
self.config = config
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/vit/modeling_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,8 +662,8 @@ def forward(
class ViTPooler(nn.Module):
def __init__(self, config: ViTConfig):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
self.dense = nn.Linear(config.hidden_size, config.pooler_output_size)
self.activation = ACT2FN[config.pooler_act]

def forward(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
Expand Down