CLIP Text Encoder #1969 (Merged)

Commits (21, all by calvinpelletier):
f40d879  CLIP tokenizer and text encoder
5dcf0d0  Merge remote-tracking branch 'origin/main' into clip_text
0a070af  formatting
8334463  switching to hf vocab file
d43f5b0  remove dependency on ftfy
a3c90ac  clip text encoder unit test
1dbe939  address comments
e6b3d19  move __call__
d501903  merge
c4e700b  addressing comments
d5b7f98  Merge remote-tracking branch 'origin/main' into clip_text
5aa7c9f  moving quickgelu
c914aa0  Merge remote-tracking branch 'origin/main' into clip_text
69a5a16  addressing comments and making tokenizer more efficient
5fe86ae  type hints
4c6ef70  tokenizer __call__
3baea1c  addressing comments
bc867ab  Merge remote-tracking branch 'origin/main' into clip_text
10c1b0d  configurable eot token
ec75cae  docstring
c215690  fix unit test
New file (55 lines): unit tests for the CLIP text encoder.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch

from torchtune.models.clip._component_builders import clip_text_encoder
from torchtune.training.seed import set_seed

VOCAB_SIZE = 512
MAX_SEQ_LEN = 77
BSZ = 2
EMBED_DIM = 4


@pytest.fixture(autouse=True)
def random():
    set_seed(0)


class TestClipTextEncoder:
    @pytest.fixture
    def model(self):
        model = clip_text_encoder(
            vocab_size=VOCAB_SIZE,
            max_seq_len=MAX_SEQ_LEN,
            embed_dim=EMBED_DIM,
            num_heads=2,
            num_layers=2,
        )
        for param in model.parameters():
            param.data.uniform_(0, 1)
        return model

    @pytest.fixture
    def inputs(self):
        return torch.randint(0, VOCAB_SIZE, (BSZ, MAX_SEQ_LEN))

    def test_forward(self, model, inputs):
        actual = model(inputs)
        expected = torch.tensor(
            [[0.2195, 1.3941, 0.6295, -0.1026], [0.2418, 1.4928, 0.6177, -0.0863]]
        )
        assert actual.shape == (BSZ, EMBED_DIM)
        torch.testing.assert_close(actual, expected, atol=1e-4, rtol=1e-4)

    def test_backward(self, model, inputs):
        y = model(inputs)
        loss = y.mean()
        loss.backward()
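The forward test pins down the encoder's contract: it takes a (batch, seq_len) tensor of token IDs and returns one pooled embedding per sequence, shape (batch, embed_dim) rather than (batch, seq_len, embed_dim). A minimal usage sketch built only from the builder signature and shapes exercised above; the dimensions are the test's toy values, not a real CLIP configuration:

import torch

from torchtune.models.clip._component_builders import clip_text_encoder

# Toy dimensions copied from the unit test above, not from a real checkpoint.
encoder = clip_text_encoder(
    vocab_size=512,
    max_seq_len=77,
    embed_dim=4,
    num_heads=2,
    num_layers=2,
)
encoder.eval()

tokens = torch.randint(0, 512, (2, 77))  # (batch, seq_len) token IDs
with torch.no_grad():
    embeddings = encoder(tokens)  # pooled output, one vector per sequence
print(embeddings.shape)  # torch.Size([2, 4])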
New file (58 lines): unit tests for the CLIP tokenizer.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import pytest

from tests.common import ASSETS
from torchtune.models.clip._model_builders import clip_tokenizer


class TestCLIPTokenizer:
    @pytest.fixture
    def tokenizer(self):
        return clip_tokenizer(ASSETS / "tiny_bpe_merges.txt")

    def test_encoding(self, tokenizer):
        texts = [
            "a cow jumping over the moon",
            "a helpful AI assistant",
        ]
        correct_tokens = [
            [2416, 320, 66, 78, 342, 73, 669, 79, 515, 326, 1190, 337, 673, 324, 76, 819, 333, 2417],
            [2416, 320, 516, 75, 79, 69, 84, 331, 64, 328, 813, 667, 540, 339, 2417],
        ]
        for text, correct in zip(texts, correct_tokens):
            tokens = tokenizer.encode(text)
            assert tokens == correct

    def test_decoding(self, tokenizer):
        text = "this is torchtune"
        decoded_text = "<|startoftext|>this is torchtune <|endoftext|>"
        assert decoded_text == tokenizer.decode(tokenizer.encode(text))

    def test_call(self, tokenizer):
        sample = {"text": "hello world"}
        sample = tokenizer(sample)
        assert "text" not in sample
        assert "tokens" in sample
New file (48 lines): state dict key conversion from Hugging Face's CLIP format to torchtune's format.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtune.models.convert_weights import get_mapped_key

# state dict key mappings from HF's format to torchtune's format
_FROM_HF = {
    "text_model.embeddings.token_embedding.weight": "token_embedding.weight",
    "text_model.embeddings.position_embedding.weight": "position_embedding",
    "text_model.encoder.layers.{}.layer_norm1.weight": "layers.{}.sa_norm.weight",
    "text_model.encoder.layers.{}.layer_norm1.bias": "layers.{}.sa_norm.bias",
    "text_model.encoder.layers.{}.layer_norm2.weight": "layers.{}.mlp_norm.weight",
    "text_model.encoder.layers.{}.layer_norm2.bias": "layers.{}.mlp_norm.bias",
    "text_model.encoder.layers.{}.mlp.fc1.weight": "layers.{}.mlp.w1.weight",
    "text_model.encoder.layers.{}.mlp.fc1.bias": "layers.{}.mlp.w1.bias",
    "text_model.encoder.layers.{}.mlp.fc2.weight": "layers.{}.mlp.w2.weight",
    "text_model.encoder.layers.{}.mlp.fc2.bias": "layers.{}.mlp.w2.bias",
    "text_model.encoder.layers.{}.self_attn.q_proj.weight": "layers.{}.attn.q_proj.weight",
    "text_model.encoder.layers.{}.self_attn.q_proj.bias": "layers.{}.attn.q_proj.bias",
    "text_model.encoder.layers.{}.self_attn.k_proj.weight": "layers.{}.attn.k_proj.weight",
    "text_model.encoder.layers.{}.self_attn.k_proj.bias": "layers.{}.attn.k_proj.bias",
    "text_model.encoder.layers.{}.self_attn.v_proj.weight": "layers.{}.attn.v_proj.weight",
    "text_model.encoder.layers.{}.self_attn.v_proj.bias": "layers.{}.attn.v_proj.bias",
    "text_model.encoder.layers.{}.self_attn.out_proj.bias": "layers.{}.attn.output_proj.bias",
    "text_model.encoder.layers.{}.self_attn.out_proj.weight": "layers.{}.attn.output_proj.weight",
    "text_model.final_layer_norm.weight": "final_norm.weight",
    "text_model.final_layer_norm.bias": "final_norm.bias",
}

_IGNORE = {
    "logit_scale",
    "text_model.embeddings.position_ids",
    "text_projection.weight",
    "visual_projection.weight",
}


def clip_text_hf_to_tune(state_dict):
    converted_state_dict = {}
    for key, value in state_dict.items():
        if key.startswith("vision_model.") or key in _IGNORE:
            continue
        new_key = get_mapped_key(key, _FROM_HF)
        converted_state_dict[new_key] = value
    return converted_state_dict
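A hedged end-to-end sketch of how this converter could be used: pull an HF CLIP checkpoint's state dict with the transformers library, remap the keys, and load them into a torchtune text encoder. The checkpoint name, the import path for this new module, and the encoder hyperparameters are all illustrative assumptions, not values taken from this PR:

from transformers import CLIPModel

from torchtune.models.clip._component_builders import clip_text_encoder
# Import path for this new file is assumed; the diff view does not show it.
from torchtune.models.clip._convert_weights import clip_text_hf_to_tune

# Illustrative checkpoint; any HF CLIP checkpoint with a text_model should work.
hf_state_dict = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").state_dict()

# Drops vision_model.* and _IGNORE keys, remaps the rest to torchtune names.
tune_state_dict = clip_text_hf_to_tune(hf_state_dict)

# Assumed hyperparameters for the ViT-B/32 text tower.
encoder = clip_text_encoder(
    vocab_size=49408,
    max_seq_len=77,
    embed_dim=512,
    num_heads=8,
    num_layers=12,
)
encoder.load_state_dict(tune_state_dict)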