
Conversation

@Swastik-Swarup-Dash

Expands class_pos_embed to match batch size #34611

@MHRDYN7 (Contributor) commented Dec 3, 2024

@Swastik-Swarup-Dash sorry for following up late on #34611. I think the following changes are needed as well:
```python
# Repeat the CLS position embedding across the batch dimension,
# and do the same for the patch position embeddings so that both
# operands of the concatenation share the same batch size.
class_pos_embed_expanded = jnp.repeat(class_pos_embed[jnp.newaxis, :], hidden_states.shape[0], axis=0)
patch_pos_embed_expanded = jnp.repeat(patch_pos_embed, hidden_states.shape[0], axis=0)
return jnp.concatenate([class_pos_embed_expanded, patch_pos_embed_expanded], axis=1)
```
If the patch position embeddings are not also repeated along the batch dimension, the concatenation afterwards fails with a shape mismatch, since the class embeddings would have a batch dimension of `hidden_states.shape[0]` while the patch embeddings still have a batch dimension of 1. You can verify this with a small test script; see the sketch below. The same expansion could also be done with `jnp.tile`.
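
For illustration, here is a minimal, self-contained sketch of the shape bookkeeping. The array sizes are made up, and the variable names mirror the snippet above rather than the actual modeling code:

```python
import jax.numpy as jnp

batch_size, num_patches, dim = 4, 16, 32            # hypothetical sizes
hidden_states = jnp.zeros((batch_size, num_patches, dim))
class_pos_embed = jnp.zeros((1, dim))               # CLS position embedding, no batch dim
patch_pos_embed = jnp.zeros((1, num_patches, dim))  # patch position embeddings, batch dim of 1

# Without repeating patch_pos_embed, concatenating (4, 1, 32) with (1, 16, 32)
# along axis=1 raises a shape error because the batch dimensions differ.
class_pos_embed_expanded = jnp.repeat(class_pos_embed[jnp.newaxis, :], hidden_states.shape[0], axis=0)
patch_pos_embed_expanded = jnp.repeat(patch_pos_embed, hidden_states.shape[0], axis=0)
pos_embed = jnp.concatenate([class_pos_embed_expanded, patch_pos_embed_expanded], axis=1)
print(pos_embed.shape)  # (4, 17, 32) -> (batch, 1 + num_patches, dim)

# The jnp.tile alternative gives the same result:
assert jnp.tile(patch_pos_embed, (batch_size, 1, 1)).shape == patch_pos_embed_expanded.shape
```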

@MHRDYN7 (Contributor) commented Dec 3, 2024

Could you also please change the following lines in the example docstring sections:

```python
>>> model = FlaxDinov2Model.from_pretrained("facebook/dinov2-base")
```

to

```python
>>> model = FlaxDinov2Model.from_pretrained("facebook/dinov2-base", from_pt=True)
```

and

```python
>>> model = FlaxDinov2ForImageClassification.from_pretrained("facebook/dinov2-base-imagenet1k-1-layer")
```

to

```python
>>> model = FlaxDinov2ForImageClassification.from_pretrained("facebook/dinov2-base-imagenet1k-1-layer", from_pt=True)
```

Seems like I missed this important detail while implementing the conversion: these checkpoints only ship PyTorch weights, so `from_pt=True` is needed for the weights to be converted on load. Thanks!
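
As a sanity check, something like the following should run end to end once both fixes are in. This is my own sketch, not part of the PR, and the expected shape assumes dinov2-base's defaults (224x224 inputs, patch size 14, hidden size 768):

```python
import jax.numpy as jnp
from transformers import FlaxDinov2Model

# from_pt=True converts the PyTorch checkpoint to Flax on load.
model = FlaxDinov2Model.from_pretrained("facebook/dinov2-base", from_pt=True)

# Batch size > 1 exercises the position-embedding expansion fixed in this PR.
pixel_values = jnp.zeros((2, 3, 224, 224), dtype=jnp.float32)
outputs = model(pixel_values=pixel_values)
print(outputs.last_hidden_state.shape)  # expected (2, 257, 768): 256 patches + 1 CLS token
```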

