Add Qwen2.5VL support #12402

Merged: 20 commits merged into ggml-org:master on Apr 27, 2025

Conversation

HimariO
Contributor

@HimariO HimariO commented Mar 15, 2025

Original issue: #11483

Changes

  • Add new gguf keys for the clip model to support:
    • GLU MLP
    • window attention
    • RMS norm
  • Updated the clip.cpp vision model to incorporate these new components (a rough ggml sketch of the new block wiring follows this list).
  • Modified qwen2_vl_surgery.py and convert_hf_to_gguf.py to support the Qwen2.5VL model.
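
To make the list above concrete, here is a minimal, illustrative sketch of how an RMS-normed, GLU-style MLP block can be wired in ggml. It is an assumption-level sketch, not the actual clip.cpp code; the helper and tensor names (norm_w, up_w, gate_w, down_w) are made up for the example.

```cpp
#include "ggml.h"

// illustrative sketch only: RMS norm + SiLU-gated (GLU) MLP, ggml-style
static ggml_tensor * glu_mlp_block(ggml_context * ctx, ggml_tensor * cur,
                                   ggml_tensor * norm_w,
                                   ggml_tensor * up_w, ggml_tensor * gate_w, ggml_tensor * down_w,
                                   float eps) {
    cur = ggml_mul(ctx, ggml_rms_norm(ctx, cur, eps), norm_w);           // RMS norm + learned scale
    ggml_tensor * up   = ggml_mul_mat(ctx, up_w, cur);                   // "up" projection
    ggml_tensor * gate = ggml_silu(ctx, ggml_mul_mat(ctx, gate_w, cur)); // SiLU-gated branch
    cur = ggml_mul(ctx, up, gate);                                       // element-wise gating (GLU)
    return ggml_mul_mat(ctx, down_w, cur);                               // "down" projection
}
```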

Model Conversion

The only change in the conversion process compared to Qwen2VL is the addition of the model_type parameter when creating the vision encoder GGUF file. (For the rest of the process and how to build llama-qwen2vl-cli, refer to #10361.)

PYTHONPATH=$PYTHONPATH:$(pwd)/gguf-py python3 examples/llava/qwen2_vl_surgery.py "/path/to/model" --data_type fp16 --model_type "qwen2.5vl"

@github-actions github-actions bot added examples python python script changes labels Mar 15, 2025
@HimariO HimariO mentioned this pull request Mar 16, 2025
@LostRuins
Collaborator

I have not converted models with the surgery myself, but I can confirm that those uploaded at https://huggingface.co/samgreen/Qwen2.5-VL-7B-Instruct-GGUF are working correctly with your changes.

@thomasht86

Waiting for this! 🙌 🙏

@HimariO HimariO marked this pull request as ready for review April 4, 2025 08:38
@abalmos

abalmos commented Apr 12, 2025

@HimariO I am having trouble with the model output stopping while only partway through its answer. I am using something like:

llama-qwen2vl-cli -m ./Qwen25-VL/Qwen2.5-VL-7B-Instruct.gguf --mmproj ./Qwen25-VL/qwen25-vll-vision.gguf --image ./test.png -t 12 --threads 12 --ctx-size 128000 --batch-size 32 -j "{}" --ignore-eos -n -1

Do you have any thoughts on what might be the cause?

@LostRuins
Collaborator

Works fine for me. Why are you using --ignore-eos? Also, instead of setting -n to -1, does it still stop prematurely if you set it to a large number?

@abalmos

abalmos commented Apr 13, 2025

I am still trying to figure this space out. Some Googling suggested that EOS can sometimes be a problem and that --ignore-eos could help. I now see how that wouldn't help here.

@LostRuins Thanks! You were right: making -n very large allowed the output to finish. I guess -n -1 does not work for this model, @HimariO?

@HimariO
Contributor Author

HimariO commented Apr 13, 2025

@abalmos It seems like the process_prompt function will set the number of output tokens to 256 if you set it to -1 (or just leave it at the default). And since qwen2vl-cli is based on llava-cli, other models will also show this behavior.
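
For illustration only, here is a tiny self-contained sketch of the fallback described above; it is not the actual llava/qwen2vl-cli source, and the helper name is made up:

```cpp
#include <cstdio>

// a non-positive n_predict silently falls back to a fixed 256-token budget
static int effective_n_predict(int n_predict) {
    return n_predict < 0 ? 256 : n_predict;
}

int main() {
    printf("%d\n", effective_n_predict(-1));   // 256 -> output stops "too early"
    printf("%d\n", effective_n_predict(8192)); // 8192 -> why a large -n lets the answer finish
}
```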

@abalmos

abalmos commented Apr 13, 2025

@HimariO Thanks. The model did produce output even with all the config at default; it would just stop too early. The flags in my first comment were just all the things I tried changing. Testing again, only the -n flag is actually needed for the JSON output I was expecting. Based on your last comment, that behavior is understood and makes sense.

Everything is working as expected with the default and a -n flag. Thanks!

@HimariO
Contributor Author

HimariO commented Apr 14, 2025

@ggerganov, I think this PR is ready for review. Please take a look when you have a moment.

@CoruNethron

CoruNethron commented Apr 18, 2025

Tested with this model: https://huggingface.co/ByteDance-Seed/UI-TARS-1.5-7B
I went through all steps: conversion of the LLM part (quantized to q5_k_m), conversion of the vision part, and inference.

Results

I had to save the prompt to a file and pass it like:

llama-qwen2vl-cli ... -p "$(cat ./prompt.txt)"

prompt.txt

You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.

## Output Format
```
Thought: ...
Action: ...
```

## Action Space

click(start_box='<|box_start|>(x1,y1)<|box_end|>')
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
hotkey(key='')
type(content='xxx')
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
wait() #Sleep for 5s and take a screenshot to check for any changes.
finished(content='xxx')


## Note
- Use English in `Thought` part.
- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.

## User Instruction
Open google and search for Henry Ford previous job title
Scrx

Model answer:

Thought: I've accessed the Google Translate page. To start searching for the previous job title of Henry Ford, I need to enter the text into the left input box. The first step is to click on the input box so that I can type in the keywords.
Action: click(start_box='<|box_start|>(290,286)<|box_end|>')

Thank you, @HimariO ! It's great!

@ggerganov ggerganov requested a review from ngxson April 19, 2025 07:22
@@ -166,6 +167,8 @@ struct clip_hparams {
std::vector<int32_t> image_grid_pinpoints;
int32_t image_crop_resolution;
std::unordered_set<int32_t> vision_feature_layer;
int32_t attn_window_size;
std::vector<int32_t> full_attn_layers;
Collaborator

Is full_attn_layers a repeated pattern? AFAIK most models have a repeated pattern like N sliding-window layers followed by 1 full layer.

If that's the case, I think it would be easier to have a simple integer int32_t n_swa_pattern. SWA means sliding window attention, and it's the same abbreviation used in libllama.

Collaborator

attn_window_size can be renamed to n_swa to align the naming scheme with libllama.

Contributor Author

Qwen2.5-VL seems to always use 1 full-attention layer after every 7 window-attention layers (aside from the 32B variant, which doesn't use window attention at all), so I think we can replace full_attn_layers with a single int value here.
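
Purely as an illustration of that proposal (the name n_wa_pattern and the helper are assumptions, not code from this PR):

```cpp
#include <cstdio>

// with n_wa_pattern = 8, every 8th layer uses full attention and the 7 layers
// before it use window attention; 0 is used here to mean "full attention everywhere"
static bool is_full_attn_layer(int il, int n_wa_pattern) {
    if (n_wa_pattern == 0) {
        return true; // e.g. the 32B variant, which has no window attention
    }
    return (il + 1) % n_wa_pattern == 0;
}

int main() {
    for (int il = 0; il < 16; il++) {
        printf("layer %2d: %s\n", il, is_full_attn_layer(il, 8) ? "full" : "window");
    }
}
```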

Contributor Author

The vision encoder’s window-attention differs from SWA. It first splits the input image into multiple patches of size attn_window_size * attn_window_size, and then restricts each image token to only interact with other tokens within the same patch during the window-attention layer.

I'm not sure it's a good idea to reuse the n_swa naming here, since n_swa is measured in the number of tokens, whereas attn_window_size is measured in pixels.

Collaborator

It sounds like the "chunk attention" in Llama 4. Do you have a visualization for it?

Does the mask look like this (example with window_size=3):

xxx------
xxx------
xxx------
---xxx---
---xxx---
---xxx---
------xxx
------xxx
------xxx
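
For reference, a tiny standalone sketch that reproduces exactly this block-diagonal pattern (illustrative only, not code from the PR):

```cpp
#include <cstdio>

int main() {
    const int n_tokens = 9, window = 3;
    for (int i = 0; i < n_tokens; i++) {
        for (int j = 0; j < n_tokens; j++) {
            // a token attends only to tokens in its own window of 3
            putchar((i / window) == (j / window) ? 'x' : '-');
        }
        putchar('\n');
    }
}
```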

Collaborator

Maybe it would be clearer if I referred to these as "windows" or "window partitions" instead of "patches," to avoid confusion with the terminology used in other models where one patch usually corresponds to one token.

Based on your explanation, I think the correct terms are:

  • "4 patches" should be "4 slices", because we already using the term "slices" from llava-uhd implementation
  • "84x112 will create (84/14)x(112/14) tokens.", "tokens" here should be "patches", because it's how they call it in the original ViT paper
  • It still unclear how to call the window. So my question is: can these "slices" being processed one-by-one? Or there is a non-causal attention cross all slices somewhere in the code?

Contributor Author

There are a few layers (the full_attn_layers mentioned earlier) in the vision encoder that provide non-causal attention across different slices, allowing patches from different slices to interact with each other.
Because of this, the slices cannot be processed one-by-one; all slices from the same image must be processed together by the vision encoder.

Collaborator

@ngxson ngxson Apr 25, 2025

Thanks for the explanation, I think I understand qwen2.5 about 80% now.

One final big question, regarding your visualization above:

Here is an actual attention mask created from a 196x140 image and default 112 attn_window_size

What I understand is that the 4 yellow "chunks" are 4 slices, and the size of each yellow "chunk" corresponds to the number of tokens in each slice. Because the H and W length of the mask is about 130, I assume that in total we have about 130 patches.

But I don't see where the value 112 is reflected in the visualization; could you explain this further?

Contributor Author

The attention mask in the visualization above has a size of 140x140, because we have (196/14)*(140/14)=140 patches.
The largest yellow chunk in the top-left corresponds to the 112×112 slice, which produces 64 patches, resulting in a 64×64 chunk in the attention mask. Similarly, the second chunk corresponds to the 84×112 slice, and so on for the others.

attn_window_size=112 indirectly limits the maximum size of each yellow chunk to 64x64 by capping the maximum size of image slices.
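
To make the arithmetic above explicit, a small standalone sketch (the values 196x140, patch_size 14, and attn_window_size 112 are taken from this example):

```cpp
#include <cstdio>

int main() {
    const int img_w = 196, img_h = 140;
    const int patch_size = 14, attn_window_size = 112;

    const int n_patches   = (img_w / patch_size) * (img_h / patch_size); // 14 * 10 = 140
    const int win_side    = attn_window_size / patch_size;               // 112 / 14 = 8
    const int win_patches = win_side * win_side;                         // 8 * 8   = 64

    printf("attention mask: %dx%d\n", n_patches, n_patches);             // 140x140
    printf("largest window chunk: %dx%d\n", win_patches, win_patches);   // 64x64
}
```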

Collaborator

Ok thanks for the explanation, that sounds clear to me. Let me think about how to organize the code so that it's easy for other people to understand this logic.

Comment on lines 370 to 372
// 1. Initialize backend
ggml_backend_t backend = NULL;
std::string backend_name = "";
Collaborator

Btw, I have https://github.com/ngxson/ggml-easy which may help in this case (I planned to support exporting tensor data as image)

Maybe it's worth exploring how to compile it with ExternalProject, so it can be an optional dependency (without including the whole file in llama.cpp like httplib or json.hpp).

Collaborator

@ngxson ngxson left a comment

I think the most important thing missing is a dedicated enum PROJECTOR_TYPE_QWEN2_5_VL

As for the rest, @HimariO, if you're busy I can take over the remaining work (refactoring the code) if you want.

@ngxson
Collaborator

ngxson commented Apr 26, 2025

Merging this once the CI is green. I also added the test I mentioned in the last comment.

Here is the pre-quantized Qwen2.5-VL 3B:

llama-qwen2vl-cli -hf ggml-org/Qwen2.5-VL-3B-Instruct-GGUF

More quants will be available once the convert_hf_to_gguf.py script supports the mmproj conversion.

@LostRuins
Collaborator

I did a quick test with my existing quants but they didn't work - though I see the q2vl surgery file has been changed and I would probably need to reconstruct the mmproj? I will redownload the model and try that again later.

@ngxson ngxson merged commit ca2bb89 into ggml-org:master Apr 27, 2025
51 checks passed
@LostRuins
Collaborator

LostRuins commented Apr 27, 2025

Hello @ngxson , the newest PR is not working correctly for me.

I reconverted https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct with the new surgery script, and the resulting mmproj loaded. However, when trying to perform inference I get a segfault.

I then tried your new quants at https://huggingface.co/ggml-org/Qwen2.5-VL-3B-Instruct-GGUF with the same result, also segfault.

The segfault seems to be happening at ggml_nbytes() on this line: https://github.com/HimariO/llama.cpp.qwen2.5vl/blob/qwen25-vl/examples/llava/clip.cpp#L3272. Looking closer, I think it should not even be in that branch; previously, in @HimariO's version, it was captured by the if (ctx->has_qwen2vl_merger) { check.

https://github.com/HimariO/llama.cpp.qwen2.5vl/blob/53a15d014f2dd0a1409e49da097dba891c629f6e/examples/llava/clip.cpp#L3142

I tried replacing
if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
with
if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL) {

but that is not enough: the inference proceeds, but the output is wrong. I think has_qwen2vl_merger was used in multiple places, so some of them need to match against PROJECTOR_TYPE_QWEN25VL as well.

Qwen2VL, gemma and others still work fine.

@LostRuins
Collaborator

LostRuins commented Apr 27, 2025

Also, I think the qwen2.5vl 3B model is broken for unrelated reasons: the output is completely incoherent, so it's not a good example to compare against.

Anyway @ngxson, I made this fix and it now works with the latest 7B q2.5vl quants. It's possibly overkill, but do take a look: LostRuins@f8b7dde. At least in practice it seems to work fine.

For those who want to try an older mmproj without reconverting, this ugly hack will allow you to load those as well: LostRuins@37060f5 (though I'm sure it's out of scope for llama.cpp).

Edit: PR for fix #13133

@ngxson
Collaborator

ngxson commented Apr 27, 2025

Please note that if you're distributing GGUF from a WIP PR, it's your responsibility to update it. For the same reason, I don't recommend distributing GGUF publicly before merging, unless the PR reaches a final review state.

Comment on lines +3282 to +3292
if (use_window_attn && ctx->proj_type == PROJECTOR_TYPE_QWEN25VL) {
struct ggml_tensor * window_idx = ggml_graph_get_tensor(gf, "window_idx");
struct ggml_tensor * inv_window_idx = ggml_graph_get_tensor(gf, "inv_window_idx");
struct ggml_tensor * window_mask = ggml_graph_get_tensor(gf, "window_mask");

const int merge_ratio = 2;
const int attn_window_size = 112;
const int pw = image_size_width / patch_size / merge_ratio;
const int ph = image_size_height / patch_size / merge_ratio;
const int grid_window = attn_window_size / patch_size / merge_ratio;
const int ipw = image_size_width / patch_size;
Collaborator

@ngxson ngxson Apr 27, 2025

Ok, now I realize that this code block looks identical to the one above (which also sets window_idx), and this is why the Qwen2.5VL 3B model doesn't work.

Still, I have no idea why this code block is duplicated. I'll try to refactor it, but just to be sure, @HimariO, can you confirm that this is NOT intentional?

Contributor Author

After looking into this further, the duplicated block is actually used to create the idx for rearranging position_ids. Only the ggml_graph_get_tensor and ggml_backend_tensor_set part is redundant.

Contributor Author

Here is an updated version that properly separates qwen2vl and qwen2.5vl logic.

Collaborator

@ngxson ngxson Apr 27, 2025

I mean, the whole outer block is duplicated in 2 places. I suppose you did it this way because you wanted to re-use the same code for both qwen2 and qwen2.5.

In any case, I ended up separating the code blocks for qwen2 and qwen2.5 so it looks cleaner. For the qwen2 code, I took the version from before this PR, so qwen2 should be correct.

But just to be sure, can you take a look at this file to see if the logic for qwen2.5 still looks correct? (The automated test passed, btw, so I assume it should be.)

case PROJECTOR_TYPE_QWEN2VL:
{
const int pw = image_size_width / patch_size;
const int ph = image_size_height / patch_size;
std::vector<int> positions(num_positions * 4);
int ptr = 0;
for (int y = 0; y < ph; y += 2) {
for (int x = 0; x < pw; x += 2) {
for (int dy = 0; dy < 2; dy++) {
for (int dx = 0; dx < 2; dx++) {
positions[ ptr] = y + dy;
positions[ num_patches + ptr] = x + dx;
positions[2 * num_patches + ptr] = y + dy;
positions[3 * num_patches + ptr] = x + dx;
ptr++;
}
}
}
}
set_input_i32("positions", positions);
} break;
case PROJECTOR_TYPE_QWEN25VL:
{
// pw * ph = number of tokens output by the ViT after applying the patch merger
// ipw * iph = number of vision tokens processed inside the ViT
const int merge_ratio = 2;
const int pw = image_size_width / patch_size / merge_ratio;
const int ph = image_size_height / patch_size / merge_ratio;
const int ipw = image_size_width / patch_size;
const int iph = image_size_height / patch_size;
std::vector<int> idx (ph * pw);
std::vector<int> inv_idx(ph * pw);
if (use_window_attn) {
const int attn_window_size = 112;
const int grid_window = attn_window_size / patch_size / merge_ratio;
int dst = 0;
// [num_vision_tokens, num_vision_tokens] attention mask tensor
std::vector<float> mask(pow(ipw * iph, 2), std::numeric_limits<float>::lowest());
int mask_row = 0;
for (int y = 0; y < ph; y += grid_window)
{
for (int x = 0; x < pw; x += grid_window)
{
const int win_h = std::min(grid_window, ph - y);
const int win_w = std::min(grid_window, pw - x);
const int dst_0 = dst;
// group all tokens belonging to the same window together (into a contiguous range)
for (int dy = 0; dy < win_h; dy++) {
for (int dx = 0; dx < win_w; dx++) {
const int src = (y + dy) * pw + (x + dx);
GGML_ASSERT(src < (int)idx.size());
GGML_ASSERT(dst < (int)inv_idx.size());
idx [src] = dst;
inv_idx[dst] = src;
dst++;
}
}
for (int r=0; r < win_h * win_w * merge_ratio * merge_ratio; r++) {
int row_offset = mask_row * (ipw * iph);
std::fill(
mask.begin() + row_offset + (dst_0 * merge_ratio * merge_ratio),
mask.begin() + row_offset + (dst * merge_ratio * merge_ratio),
0.0);
mask_row++;
}
}
}
set_input_i32("window_idx", idx);
set_input_i32("inv_window_idx", inv_idx);
set_input_f32("window_mask", mask);
} else {
for (int i = 0; i < ph * pw; i++) {
idx[i] = i;
}
}
const int mpow = merge_ratio * merge_ratio;
std::vector<int> positions(num_positions * 4);
int ptr = 0;
for (int y = 0; y < iph; y += merge_ratio) {
for (int x = 0; x < ipw; x += merge_ratio) {
for (int dy = 0; dy < 2; dy++) {
for (int dx = 0; dx < 2; dx++) {
auto remap = idx[ptr / mpow];
remap = (remap * mpow) + (ptr % mpow);
positions[ remap] = y + dy;
positions[ num_patches + remap] = x + dx;
positions[2 * num_patches + remap] = y + dy;
positions[3 * num_patches + remap] = x + dx;
ptr++;
}
}
}
}
set_input_i32("positions", positions);
} break;

Contributor Author

Yeah, the logic looks fine to me.

Comment on lines +3154 to +3160
if (use_window_attn) {
const int attn_window_size = 112;
struct ggml_tensor * window_idx = ggml_graph_get_tensor(gf, "window_idx");
struct ggml_tensor * inv_window_idx = ggml_graph_get_tensor(gf, "inv_window_idx");
struct ggml_tensor * window_mask = ggml_graph_get_tensor(gf, "window_mask");

const int grid_window = attn_window_size / patch_size / merge_ratio;
Collaborator

This code block is almost identical to the one mentioned in the last comment, but the actual logic is different.

pockers21 pushed a commit to pockers21/llama.cpp that referenced this pull request Apr 28, 2025
* implment vision model architecture, gguf convertor

* handle window attention inputs

* add debug utils

* fix few incorrect tensor memory layout

* move position id remap out of ggml to avoid int32 cuda operations

* cleaning up

* ignore transformers Qwen2_5_xxx type check

* remove not so often use `qwen2vl-cli` debug functions

* remove commented-out code blocks

* fix attn weight scaling after rebase

* add `PROJECTOR_TYPE_QWEN2_5_VL`

* remove `KEY_USE_GLU_MLP`, `KEY_USE_RMS_NORM`

* replace `KEY_FULLATTN_BLK_IDX` with `KEY_WIN_ATTN_PATTERN`

* remove `attn_window_size` from gguf

* fix model conversion

* clean up

* fix merging problem

* add test

---------

Co-authored-by: Xuan Son Nguyen <[email protected]>
pockers21 pushed a commit to pockers21/llama.cpp that referenced this pull request Apr 28, 2025
@ngxson
Collaborator

ngxson commented Apr 30, 2025

I'm testing with the 32B model and I realized that the text model does not work; it responds with @@@@@@@@@... repeatedly. @LostRuins have you tested with Qwen 2.5 VL 32B? Here is our pre-quant: https://huggingface.co/ggml-org/Qwen2.5-VL-32B-Instruct-GGUF

I'm trying with llama-cli btw, no vision here.

@HimariO Also, I think the 32B model does use fullatt_block_indexes. The reason you see it missing from config.json is that transformers excludes some keys from the json if they are equal to the default value. I don't know why this only happens with the 32B, but I'm pretty sure that's the case here. If the model didn't use it, it would have been an empty array: fullatt_block_indexes: []
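
As a side note, one generic way to handle such "missing means default" metadata when later reading the converted gguf could look like the sketch below. This is illustrative only; the key handling and helper name are assumptions, not code from this PR or from clip.cpp:

```cpp
#include "gguf.h"

// return the value stored under `key`, or `def` when the key was omitted
// because it matched the model's default
static int32_t gguf_get_i32_or(const gguf_context * ctx, const char * key, int32_t def) {
    const int64_t idx = gguf_find_key(ctx, key);
    return idx < 0 ? def : gguf_get_val_i32(ctx, idx);
}
```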

@LostRuins
Collaborator

LostRuins commented May 1, 2025

Hello @ngxson, I just tried the text model alone, https://huggingface.co/ggml-org/Qwen2.5-VL-32B-Instruct-GGUF/blob/main/Qwen2.5-VL-32B-Instruct-Q4_K_M.gguf, with 40 layers loaded on the Vulkan backend, and it works perfectly fine.

Tried again with CUDA backend, 40 layers, also no issue.

I did not test flash attention.

I am on Windows 10 x86_64, Nvidia RTX 4090 laptop (driver 566.36 so no coopmat2), Intel i9-13980hx.

Did you use the exact same quant download as above? Can you give me a set of launch args that do not work?

@ngxson
Collaborator

ngxson commented May 1, 2025

Yes, I use the exact quant above; the command is llama-cli -hf ggml-org/Qwen2.5-VL-32B-Instruct-GGUF:Q4_K_M

I haven't tested CPU-only, but if it works on CUDA and Vulkan, there is probably a problem with the Metal backend.

@ggerganov
Member

@ngxson The 32B model works on my M2 Studio:

[screenshot]

@ngxson
Collaborator

ngxson commented May 2, 2025

This is how it looks on my system (mac M3 ultra)

[screenshot]

Note: llamac is a custom bash macro of mine that runs cmake and then the binary in one go.


with -ngl 0 and -DGGML_METAL=OFF:

[screenshot]

@LostRuins
Collaborator

I can't test on Mac, but I can confirm coherence on CUDA, Vulkan, and CPU.

@danielhanchen
Contributor

@ngxson Interestingly, when I convert the 72B VL, I get extremely high perplexity values for Qwen 2.5 VL 72B Instruct. Weirdly, I'm getting 20 to 70 after BF16 conversion.
