Support Kosmos-2.5 #31711

Open
wants to merge 441 commits into base: main

Conversation


@tic-top tic-top commented Jun 29, 2024

What does this PR do?

Implementation of Kosmos-2.5 in transformers (#30877).
https://huggingface.co/kirp/kosmos2_5/blob/main/README.md

Usage

from PIL import Image
import requests
import torch
from transformers import AutoProcessor, AutoModelForVision2Seq, AutoConfig
import re

repo = "kirp/kosmos2_5"
device = "cuda:0"
config = AutoConfig.from_pretrained(repo)

NAME = {
    "f" : "flash_attention_2",
    "s" : "sdpa",
    "e" : "eager",
}

# all sdpa fp16
dtype = torch.float16
config._attn_implementation = NAME["s"]
config.vision_config._attn_implementation = NAME["s"]
config.text_config._attn_implementation = NAME["s"]

# # all eager bf16
# dtype = torch.bfloat16
# config._attn_implementation = NAME["e"]
# config.text_config._attn_implementation = NAME["e"]
# config.vision_config._attn_implementation = NAME["e"]


model = AutoModelForVision2Seq.from_pretrained(repo, device_map=device, torch_dtype=dtype, config=config)
processor = AutoProcessor.from_pretrained(repo)

url = "https://huggingface.co/kirp/kosmos2_5/resolve/main/receipt_00008.png"
image = Image.open(requests.get(url, stream=True).raw)
prompt = "<ocr>" # <md>

inputs = processor(text=prompt, images=image, return_tensors="pt")
height, width = inputs.pop("height"), inputs.pop("width")
raw_width, raw_height = image.size
# the processor resizes the image, so compute factors to map predicted boxes back to the original resolution
scale_height = raw_height / height
scale_width = raw_width / width

inputs = {k: v.to(device) if v is not None else None for k, v in inputs.items()}
inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)

generated_ids = model.generate(
    **inputs,
    max_new_tokens=1024,
)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)

# Convert the generated <bbox> tokens back to coordinates in the original image (OCR prompt only).
def postprocess(y, scale_height, scale_width):
    y = y.replace(prompt, "")
    if "<md>" in prompt:
        return y
    pattern = r"<bbox><x_\d+><y_\d+><x_\d+><y_\d+></bbox>"
    bboxs_raw = re.findall(pattern, y)
    lines = re.split(pattern, y)[1:]
    bboxs = [re.findall(r"\d+", i) for i in bboxs_raw]
    bboxs = [[int(j) for j in i] for i in bboxs]
    info = ""
    for i in range(len(lines)):
        box = bboxs[i]
        x0, y0, x1, y1 = box
        if not (x0 >= x1 or y0 >= y1):
            x0 = int(x0 * scale_width)
            y0 = int(y0 * scale_height)
            x1 = int(x1 * scale_width)
            y1 = int(y1 * scale_height)
            info += f"{x0},{y0},{x1},{y0},{x1},{y1},{x0},{y1},{lines[i]}"
    return info

output_text = postprocess(generated_text[0], scale_height, scale_width)
print(output_text)

@amyeroberts
Collaborator

cc @ydshieh

@ydshieh ydshieh self-assigned this Jul 1, 2024
@ydshieh
Collaborator

ydshieh commented Jul 9, 2024

Thanks a lot for this hard work @tic-top. This is going to benefit the community 🤗! I will check this tomorrow!

@ydshieh
Collaborator

ydshieh commented Jul 10, 2024

I will push an empty commit to trigger a CI running on GPU 🙏

Collaborator

@ydshieh ydshieh left a comment

Hi. Apologies for being late.

I left one quick question, and I will focus on reviewing this PR tomorrow.

@ydshieh
Collaborator

ydshieh commented Jul 16, 2024

No worries about the failing CI above. It's fine; I checked the tests running on an A10 and they pass ✅

Collaborator

(for me): I will revert this

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/kosmos_2_overview.jpg"
alt="drawing" width="600"/>

<small> Overview of tasks that KOSMOS-2 can handle. Taken from the <a href="https://arxiv.org/abs/2306.14824">original paper</a>. </small>
Collaborator

@ydshieh ydshieh Jul 17, 2024

If we don't keep the image above, we should remove this too. Otherwise, change KOSMOS-2 -> KOSMOS-2.5.

Comment on lines 35 to 81
from PIL import Image
import requests
import torch
from transformers import AutoProcessor, Kosmos2_5ForConditionalGeneration
import re
repo = "microsoft/kosmos-2.5"
device = "cuda:0"
dtype = torch.bfloat16
model = Kosmos2_5ForConditionalGeneration.from_pretrained(repo, device_map=device, torch_dtype=dtype)
processor = AutoProcessor.from_pretrained(repo)
url = "https://huggingface.co/kirp/kosmos2_5/resolve/main/receipt_00008.png"
image = Image.open(requests.get(url, stream=True).raw)
prompt = "<ocr>" # <md>
inputs = processor(text=prompt, images=image, return_tensors="pt")
height, width = inputs.pop("height"), inputs.pop("width")
raw_width, raw_height = image.size
scale_height = raw_height / height
scale_width = raw_width / width
inputs = {k: v.to(device) if v is not None else None for k, v in inputs.items()}
inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)
generated_ids = model.generate(
    **inputs,
    max_new_tokens=1024,
)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
def postprocess(y, scale_height, scale_width):
    y = y.replace(prompt, "")
    if "<md>" in prompt:
        return y
    pattern = r"<bbox><x_\d+><y_\d+><x_\d+><y_\d+></bbox>"
    bboxs_raw = re.findall(pattern, y)
    lines = re.split(pattern, y)[1:]
    bboxs = [re.findall(r"\d+", i) for i in bboxs_raw]
    bboxs = [[int(j) for j in i] for i in bboxs]
    info = ""
    for i in range(len(lines)):
        box = bboxs[i]
        x0, y0, x1, y1 = box
        if not (x0 >= x1 or y0 >= y1):
            x0 = int(x0 * scale_width)
            y0 = int(y0 * scale_height)
            x1 = int(x1 * scale_width)
            y1 = int(y1 * scale_height)
            info += f"{x0},{y0},{x1},{y0},{x1},{y1},{x0},{y1},{lines[i]}"
    return info
output_text = postprocess(generated_text[0], scale_height, scale_width)
print(output_text)
Collaborator

nit: (not necessary)

Might be nice / interesting to refer to

https://github.com/microsoft/unilm/blob/master/kosmos-2.5/draw_bbox.py

and attach a screenshot of the output images.

Author

@tic-top tic-top Jul 22, 2024

This one is easier to use.
The Python file above needs to convert the string to JSON first, then draw.

Collaborator

Yes. But if I understand correctly, this only gives the string, while people are more interested in seeing the final images with bounding boxes or the structured MD layout.

I am not saying to use draw_bbox.py in this documentation. Just mention that there is such a file to draw things and give the link as a reference.

If you have any reason not to mention it, I am OK not having it here.

Author

Do you mean something like How to use?

Comment on lines +287 to +292
seqlen = self.model_tester.text_model_tester.seq_length
inputs_dict["input_ids"] = torch.arange(seqlen, device=torch_device).unsqueeze(0).expand(bs, seqlen)
inputs_dict["input_ids"] = inputs_dict["input_ids"] % self.model_tester.text_model_tester.vocab_size
inputs_dict["attention_mask"] = torch.ones((bs, seqlen), device=torch_device)
inputs_dict["image_embeds_position_mask"] = torch.zeros((bs, seqlen), device=torch_device)
inputs_dict["image_embeds_position_mask"][:, : self.model_tester.latent_query_num] = 1
Collaborator

Is this necessary only to adjust the batch size? I think the default value will give the same batch size.

(For Kosmos-2, I don't need this extra block to adjust anything)

Author

When I test it, it returns an unexpected result.

Comment on lines 561 to 611
class Kosmos2_5ModelIntegrationTest(unittest.TestCase):
    def run_example(self, prompt, image, model, processor):
        print("Prompt:", prompt)
        inputs = processor(text=prompt, images=image, return_tensors="pt")
        _, _ = inputs.pop("height"), inputs.pop("width")
        inputs = {k: v.to(torch_device) if v is not None else None for k, v in inputs.items()}
        inputs["flattened_patches"] = inputs["flattened_patches"].to(model.dtype)

        generation_outputs = model.generate(
            **inputs,
            max_new_tokens=1024,
        )
        generated_ids = generation_outputs
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)

        return generated_ids, generated_text

    def test_receipt_image_ocr(self):
        url = "https://huggingface.co/microsoft/kosmos-2.5/resolve/main/receipt_00008.png"
        url = "https://huggingface.co/kirp/kosmos2_5/resolve/main/receipt_00008.png"
        image = Image.open(requests.get(url, stream=True).raw)

        dtype = torch.bfloat16
        repo = "microsoft/kosmos-2.5"
        model = Kosmos2_5ForConditionalGeneration.from_pretrained(repo, device_map=torch_device, torch_dtype=dtype)
        processor = AutoProcessor.from_pretrained(repo)
        prompt = "<ocr>"
        generated_ids, generated_text = self.run_example(prompt, image, model, processor)

        EXPECTED_TEXT = [
            "<ocr><bbox><x_53><y_573><x_69><y_606></bbox>1\n<bbox><x_79><y_573><x_464><y_611></bbox>[REG] BLACK SAKURA\n<bbox><x_690><y_569><x_810><y_606></bbox>45,455\n<bbox><x_53><y_614><x_69><y_648></bbox>1\n<bbox><x_79><y_614><x_468><y_650></bbox>COOKIE DOH SAUCES\n<bbox><x_788><y_609><x_812><y_644></bbox>0\n<bbox><x_50><y_658><x_69><y_693></bbox>1\n<bbox><x_79><y_658><x_358><y_693></bbox>NATA DE COCO\n<bbox><x_790><y_652><x_814><y_687></bbox>0\n<bbox><x_31><y_742><x_820><y_781></bbox>Sub Total 45,455\n<bbox><x_27><y_781><x_822><y_827></bbox>PB1 (10%) 4,545\n<bbox><x_27><y_826><x_824><y_872></bbox>Rounding 0\n<bbox><x_24><y_872><x_827><y_921></bbox>Total 50,000\n<bbox><x_17><y_1056><x_836><y_1108></bbox>Card Payment 50,000\n"
        ]

        self.assertListEqual(generated_text, EXPECTED_TEXT)

    def test_receipt_image_md(self):
        url = "https://huggingface.co/microsoft/kosmos-2.5/resolve/main/receipt_00008.png"
        url = "https://huggingface.co/kirp/kosmos2_5/resolve/main/receipt_00008.png"
        image = Image.open(requests.get(url, stream=True).raw)

        dtype = torch.bfloat16
        repo = "microsoft/kosmos-2.5"
        model = Kosmos2_5ForConditionalGeneration.from_pretrained(repo, device_map=torch_device, torch_dtype=dtype)
        processor = AutoProcessor.from_pretrained(repo)
        prompt = "<md>"
        generated_ids, generated_text = self.run_example(prompt, image, model, processor)
        print(generated_text)
        EXPECTED_TEXT = [
            "<md>- **1 \\[REG\\] BLACK SAKURA** 45,455\n- **1 COOKIE DOH SAUCES** 0\n- **1 NATA DE COCO** 0\n- **Sub Total** 45,455\n- **PB1 (10%)** 4,545\n- **Rounding** 0\n- **Total** **50,000**\n\nCard Payment 50,000"
        ]
        self.assertListEqual(generated_text, EXPECTED_TEXT)
Collaborator

Since we skip some SDPA/Flash attention tests above 🙏, could you add them to the integration tests? The tests are likely identical, just using pure/SDPA/Flash attention.

Author

added

Collaborator

Thank you. Since our CI is still using a T4 runner and none of these 3 tests pass (GPU OOM), I am thinking of reducing max_new_tokens=1024 to something smaller.

Do you have any comment about this?

Collaborator

If you think that makes sense too, I can run on a T4 and update the expected output values on my side.

Collaborator

Well, it turns out that the GPU OOM already happens at vision_model_output = self.vision_model(...), so reducing max_new_tokens won't help here. Forget about my comment above.

Comment on lines 31 to 38
@require_vision
class LlavaProcessorTest(unittest.TestCase):
    def test_can_load_various_tokenizers(self):
        # for checkpoint in ["microsoft/kosmos-2.5", "microsoft/kosmos-2.5"]:
        for checkpoint in ["kirp/kosmos2_5"]:
            processor = AutoProcessor.from_pretrained(checkpoint)
            tokenizer = AutoTokenizer.from_pretrained(checkpoint)
            self.assertEqual(processor.tokenizer.__class__, tokenizer.__class__)
Collaborator

This should be similar to Kosmos2ProcessorTest.

The most important one is test_full_processor.

I did very extensive tests there, but you can keep it simple.
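
For reference, a minimal sketch of what such a test could look like (the checkpoint, prompt and image URL come from this PR; the class name and the checked keys are assumptions based on the usage snippet in the PR description):

import unittest

import requests
from PIL import Image

from transformers import AutoProcessor
from transformers.testing_utils import require_vision


@require_vision
class Kosmos2_5ProcessorTest(unittest.TestCase):
    def test_full_processor(self):
        processor = AutoProcessor.from_pretrained("kirp/kosmos2_5")
        url = "https://huggingface.co/kirp/kosmos2_5/resolve/main/receipt_00008.png"
        image = Image.open(requests.get(url, stream=True).raw)
        inputs = processor(text="<ocr>", images=image, return_tensors="pt")
        # keys assumed from the usage snippet in the PR description
        for key in ("input_ids", "flattened_patches", "height", "width"):
            self.assertIn(key, inputs)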

Author

added

Collaborator

@ydshieh ydshieh left a comment

Hi @tic-top, it looks very good! A few nit comments + a few things for the tests.

I will have to continue and finalize the review tomorrow, but I am submitting what I have so far.

One important point for me is the comment I left about the model integration tests.

@ydshieh
Collaborator

ydshieh commented Jul 17, 2024

Ah, I forgot. We should also have a test file (class Kosmos2_5ImageProcessorTest) for class Kosmos2_5ImageProcessor
(similar to Pix2StructImageProcessingTest for Pix2StructImageProcessor).

It should be easy: just copy Pix2StructImageProcessingTest and make only a few changes.

(But for this, you can wait - I still have a few things to review tomorrow)
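
A rough sketch, under the assumption that Kosmos2_5ImageProcessor mirrors Pix2StructImageProcessor, is exposed as in the import-structure change in this PR, and returns the flattened_patches consumed in the usage snippet (constructor defaults are assumed):

import unittest

import numpy as np
from PIL import Image

from transformers import Kosmos2_5ImageProcessor


class Kosmos2_5ImageProcessorTest(unittest.TestCase):
    def test_call_returns_flattened_patches(self):
        image_processor = Kosmos2_5ImageProcessor()
        image = Image.fromarray(np.random.randint(0, 256, (512, 512, 3), dtype=np.uint8))
        outputs = image_processor(images=image, return_tensors="pt")
        self.assertIn("flattened_patches", outputs)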

@@ -1149,6 +1154,7 @@
_import_structure["models.idefics2"].extend(["Idefics2ImageProcessor"])
_import_structure["models.imagegpt"].extend(["ImageGPTFeatureExtractor", "ImageGPTImageProcessor"])
_import_structure["models.instructblipvideo"].extend(["InstructBlipVideoImageProcessor"])
_import_structure["models.kosmos2_5"].extend(["Kosmos2_5ImageProcessor", "Kosmos2_5Processor"])
Author

Where should I add it?

@@ -5821,6 +5839,7 @@
from .models.idefics2 import Idefics2ImageProcessor
from .models.imagegpt import ImageGPTFeatureExtractor, ImageGPTImageProcessor
from .models.instructblipvideo import InstructBlipVideoImageProcessor
from .models.kosmos2_5 import Kosmos2_5ImageProcessor, Kosmos2_5Processor
Author

Where?

Author

@tic-top tic-top left a comment

Where should I add the processor?

repo = "microsoft/kosmos-2.5"
model = Kosmos2_5ForConditionalGeneration.from_pretrained(
    repo, device_map=torch_device, torch_dtype=dtype
)  # , attn_implementation="eager")
Collaborator

When not specified and SDPA is available, it will actually use SDPA. Hence we have to specify "eager" explicitly in order to test "eager".
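
For example (a minimal sketch, reusing the names from the quoted test above):

model = Kosmos2_5ForConditionalGeneration.from_pretrained(
    repo, device_map=torch_device, torch_dtype=dtype, attn_implementation="eager"
)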

Collaborator

I pushed a commit for this.

        ]
        self.assertListEqual(generated_text, EXPECTED_TEXT)

    def test_FA2(self):
Collaborator

Need

    @require_flash_attn
    @require_torch_gpu
    @pytest.mark.flash_attn_test
    @slow

Collaborator

(before def test_FA2(self):)
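
i.e. something like the following sketch (the decorators are the standard ones from transformers.testing_utils plus the pytest marker; the test body is elided):

import unittest

import pytest

from transformers.testing_utils import require_flash_attn, require_torch_gpu, slow


class Kosmos2_5ModelIntegrationTest(unittest.TestCase):
    @require_flash_attn
    @require_torch_gpu
    @pytest.mark.flash_attn_test
    @slow
    def test_FA2(self):
        ...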

Collaborator

I pushed the change for this.

Collaborator

@ydshieh ydshieh left a comment

Here are some more comments. I have to stop in the middle as there is one thing I want to mention in particular.

)

batch_size, seq_len = input.input_ids.shape
additional_tokens = [0, 100283] + [0] * 2048 + [100284]
Collaborator

Better to have names for such special tokens, i.e. what 100283 and 100284 mean.
And the values should not be hardcoded. Something like

boi_token_id = tokenizer(boi_token)
eoi_token_id = tokenizer(eoi_token)
additional_tokens = [0, boi_token_id] + [0] * 2048 + [eoi_token_id]

would be fine (adjust the names of course).
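
A hedged sketch of that suggestion (the image-boundary token strings below are placeholders, not necessarily the ones the Kosmos-2.5 tokenizer defines; convert_tokens_to_ids is the usual way to resolve a single token to its id):

# Placeholders: use whatever image-boundary tokens the Kosmos-2.5 tokenizer actually defines.
boi_token_id = tokenizer.convert_tokens_to_ids("<image>")
eoi_token_id = tokenizer.convert_tokens_to_ids("</image>")
num_image_tokens = 2048  # length of the image placeholder span (hardcoded in the quoted test)

additional_tokens = [0, boi_token_id] + [0] * num_image_tokens + [eoi_token_id]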

Collaborator

Also is it possible to replace 2048 with something not hardcoded?

Author

2048 is fixed

Collaborator

@ydshieh ydshieh left a comment

Thank you for this work @tic-top ! I am sure the community would be happy to have this available!

@ydshieh
Collaborator

ydshieh commented May 7, 2025

run-slow: kosmos2_5

Contributor

github-actions bot commented May 7, 2025

This comment contains run-slow, running the specified jobs:

models: ['models/kosmos2_5']
quantizations: [] ...

@ydshieh
Collaborator

ydshieh commented May 7, 2025

run-slow: kosmos2_5

Contributor

github-actions bot commented May 7, 2025

This comment contains run-slow, running the specified jobs:

models: ['models/kosmos2_5']
quantizations: [] ...

Collaborator

@ArthurZucker ArthurZucker left a comment

Thanks! We need to use an XXCrossAttention layer and a base model (at the cost of not exposing everything); otherwise good to go.


# use encoder_hidden_states if cross attention
is_cross_attention = encoder_hidden_states is not None
current_states = encoder_hidden_states if is_cross_attention else hidden_states
Collaborator

The weird part about this is that you end up always computing the cross k and v, never using the cache, which is not standard!
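
To illustrate the standard pattern being referred to (generic names, not this PR's actual API): the cross-attention key/value states depend only on the encoder states, so they can be computed once on the first decoding step and then reused from the cache.

def cross_attention_kv(k_proj, v_proj, encoder_hidden_states, past_key_value=None):
    # Standard pattern: compute the cross-attention k/v once, then reuse the cached pair.
    if past_key_value is not None:
        return past_key_value
    key_states = k_proj(encoder_hidden_states)
    value_states = v_proj(encoder_hidden_states)
    return (key_states, value_states)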

Comment on lines +1259 to +1261
past_key_value=None,
attention_mask=None,
output_attentions=None,
Collaborator

Given that the past is not used, I would be in favor of having a Kosmos2_5CrossAttention class that shows this.


**OCR Task:** For usage instructions, please refer to [ocr.py](https://huggingface.co/microsoft/kosmos-2.5/blob/main/ocr.py).


Collaborator

It is missing some snippets about how to use, for example, the extra bboxes and a post-processor to plot the boxes on the image.
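
A hedged sketch of such a snippet (image and output_text are the variables from the usage example in the PR description; each postprocessed line has the form x0,y0,x1,y0,x1,y1,x0,y1,text):

from PIL import ImageDraw

draw = ImageDraw.Draw(image)
for line in output_text.splitlines():
    parts = line.split(",")
    if len(parts) < 9:
        continue
    coords = list(map(int, parts[:8]))
    # the 8 numbers are the 4 corners of the box: (x0, y0), (x1, y0), (x1, y1), (x0, y1)
    draw.polygon(list(zip(coords[0::2], coords[1::2])), outline="red")
image.save("receipt_with_boxes.png")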

@ydshieh
Collaborator

ydshieh commented May 8, 2025

Thank you @ArthurZucker for the review.

> We need to use a XXCrossAttention layer, a Basemodel

At this stage, I tried not to change the modeling code in any way that would require the model weights to be changed and re-created/re-uploaded.

Currently, the cross attention in Kosmos-2.5 just reuses the TextAttention with is_causal=False.

> the weird part about this is that you end up always computing the cross k and v, never using the cache, which is not standard!

During generate, the image_to_text_projection (and vision_model) is computed only once, under if image_embeds is None, so it is OK.

I agree that using TextAttention here is a bit strange. I will try to use XXCrossAttention (assuming no model weight changes) 🙏

@ydshieh
Collaborator

ydshieh commented May 8, 2025

Regarding

#31711 (comment)

@zucchini-nlp is this something we discussed before? Could one of you share some more details on what I could do? There is already a base model Kosmos2_5Model, but Kosmos2_5ForConditionalGeneration is not in the standard form that @zucchini-nlp mentioned to me before, and changing that would require model weight changes 😢

@zucchini-nlp
Member

Yes, I would love to have it in standard form if possible. For vLLM it is not a blocker, because we are anyway messing with state dict keys by replacing prefixes, and Kosmos-2.5 will be one of the replaced models.

Internally we talked with Yih-Dar; converting the model weights again might be painful, and I am not sure the authors would accept it (related to the internal thread on new model addition with weight conversion). Therefore, I won't push much and am happy with it as is, given that the PR has been open since last year 💀
