clip : fix pixtral on some GPU backends #13097

Merged · 5 commits · Apr 25, 2025
68 changes: 46 additions & 22 deletions examples/llava/clip.cpp
@@ -556,15 +556,15 @@ static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_im
 }
 
 // implementation of the 2D RoPE without adding a new op in ggml
+// this is not efficient (uses double the memory), but works on all backends
+// TODO: there was a more efficient implementation which relies on ggml_view and ggml_rope_ext_inplace, but the inplace rope does not work well with non-contiguous tensors; we should fix that and revert back to the original implementation in https://github.com/ggml-org/llama.cpp/pull/13065
 static ggml_tensor * build_rope_2d(
-    ggml_cgraph * gf,
     ggml_context * ctx0,
     ggml_tensor * cur,
     ggml_tensor * pos_h,
     ggml_tensor * pos_w,
     const float freq_base
 ) {
-    ggml_tensor * tmp;
     const int64_t n_dim  = cur->ne[0];
     const int64_t n_head = cur->ne[1];
     const int64_t n_pos  = cur->ne[2];
@@ -573,18 +573,23 @@ static ggml_tensor * build_rope_2d(
     // we will have a list of 4 inv_freq: 1e-0, 1e-1, 1e-2, 1e-3
     // first half of cur will use 1e-0, 1e-2 (even)
     // second half of cur will use 1e-1, 1e-3 (odd)
-    //
-    // for the first half, the trick here is to rotate n_dim/2, so inv_freq will be even
+    // the trick here is to rotate just half of n_dim, so inv_freq will automatically be even
     //  ^ don't ask me why, it's math! -2(2i) / n_dim == -2i / (n_dim/2)
     // then for the second half, we use freq_scale to shift the inv_freq
     //  ^ why? replace (2i) with (2i+1) in the above equation
     const float freq_scale_odd = std::pow(freq_base, (float)-2/n_dim);
 
     // first half
+    ggml_tensor * first;
     {
-        cur = ggml_rope_ext_inplace(
+        first = ggml_view_3d(ctx0, cur,
+            n_dim/2, n_head, n_pos,
+            ggml_row_size(cur->type, n_dim),
+            ggml_row_size(cur->type, n_dim*n_head),
+            0);
+        first = ggml_rope_ext(
             ctx0,
-            cur,
+            first,
             pos_h, // positions
             nullptr, // freq factors
             n_dim/2, // n_dims
@@ -594,26 +599,27 @@ static ggml_tensor * build_rope_2d(
     }
 
     // second half
+    ggml_tensor * second;
     {
-        tmp = ggml_view_3d(ctx0, cur,
+        second = ggml_view_3d(ctx0, cur,
             n_dim/2, n_head, n_pos,
             ggml_row_size(cur->type, n_dim),
             ggml_row_size(cur->type, n_dim*n_head),
             n_dim/2 * ggml_element_size(cur));
-        tmp = ggml_rope_ext_inplace(
+        second = ggml_cont(ctx0, second); // copy, because ggml_rope doesn't play well with non-contiguous tensors

[Review comment · Member] If you suspect that ggml_rope is not implemented correctly for non-contiguous tensors, please add a test to test-backend-ops that shows the problem.

+        second = ggml_rope_ext(
             ctx0,
-            tmp,
+            second,
             pos_w, // positions
             nullptr, // freq factors
             n_dim/2, // n_dims
             0, 0, freq_base,
             freq_scale_odd,
             0.0f, 1.0f, 0.0f, 0.0f
         );
-        // calculate inplace (modify cur directly)
-        ggml_build_forward_expand(gf, tmp);
     }
 
+    cur = ggml_concat(ctx0, first, second, 0);
     return cur;
 }
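
The "it's math!" comment above is doing real work: calling rope with n_dims = n_dim/2 on a half-width slice reproduces exactly the even-index inverse frequencies of a full-width rope, and scaling by freq_scale_odd = freq_base^(-2/n_dim) shifts that series to the odd indices. A standalone numeric check of this identity (a sketch for illustration, not part of the PR):

// Verifies: b^(-2i/(d/2)) == b^(-2(2i)/d)             (even series)
//           and multiplying by b^(-2/d) gives b^(-2(2i+1)/d) (odd series)
#include <cassert>
#include <cmath>
#include <cstdio>

int main() {
    const double b = 10000.0; // freq_base
    const int    d = 8;       // n_dim (head dim, must be even)

    const double freq_scale_odd = std::pow(b, -2.0 / d);
    for (int i = 0; i < d/4; i++) {
        const double even = std::pow(b, -2.0*(2*i)     / d); // full-width rope, dim 2i
        const double odd  = std::pow(b, -2.0*(2*i + 1) / d); // full-width rope, dim 2i+1
        const double half = std::pow(b, -2.0*i / (d/2.0));   // half-width rope, dim i
        assert(std::fabs(half - even) < 1e-12);               // first half hits the even series
        assert(std::fabs(half*freq_scale_odd - odd) < 1e-12); // second half hits the odd series
    }
    std::printf("even/odd frequency split verified\n");
    return 0;
}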

@@ -682,13 +688,13 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
         struct ggml_tensor * Q = ggml_mul_mat(ctx0, model.layers[il].q_w, cur);
 
         Q = ggml_reshape_3d(ctx0, Q, d_head, n_head, num_patches);
-        Q = build_rope_2d(gf, ctx0, Q, pos_h, pos_w, hparams.rope_theta);
+        Q = build_rope_2d(ctx0, Q, pos_h, pos_w, hparams.rope_theta);
         Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
 
         struct ggml_tensor * K = ggml_mul_mat(ctx0, model.layers[il].k_w, cur);
 
         K = ggml_reshape_3d(ctx0, K, d_head, n_head, num_patches);
-        K = build_rope_2d(gf, ctx0, K, pos_h, pos_w, hparams.rope_theta);
+        K = build_rope_2d(ctx0, K, pos_h, pos_w, hparams.rope_theta);
         K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
 
         struct ggml_tensor * V = ggml_mul_mat(ctx0, model.layers[il].v_w, cur);
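
For context, build_rope_2d takes two position tensors: pos_h drives the first half of each head dimension and pos_w the second half. A sketch of how such tensors would be filled for a row-major grid of patches — fill_pos_2d is a hypothetical helper, and the assumption that each patch is simply tagged with its row and column index is mine, not taken from the PR:

#include <cstdint>
#include <vector>

// pos_h[i] = row of patch i, pos_w[i] = column of patch i
static void fill_pos_2d(int n_py, int n_px,
                        std::vector<int32_t> & pos_h,
                        std::vector<int32_t> & pos_w) {
    pos_h.resize((size_t) n_py * n_px);
    pos_w.resize((size_t) n_py * n_px);
    for (int y = 0; y < n_py; y++) {
        for (int x = 0; x < n_px; x++) {
            pos_h[y * n_px + x] = y; // rotates the first half  (via pos_h)
            pos_w[y * n_px + x] = x; // rotates the second half (via pos_w)
        }
    }
}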
@@ -2804,10 +2810,15 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
     const auto & model = ctx->vision_model;
     const auto & hparams = model.hparams;
 
+    // TODO @ngxson : this is ugly, need to refactor later
+    bool support_dynamic_size = ctx->has_minicpmv_projector
+        || ctx->has_qwen2vl_merger
+        || ctx->proj_type == PROJECTOR_TYPE_PIXTRAL;
+
     const int image_size = hparams.image_size;
     int image_size_width  = image_size;
     int image_size_height = image_size;
-    if (ctx->has_minicpmv_projector | ctx->has_qwen2vl_merger) {
+    if (support_dynamic_size) {
         image_size_width  = imgs.entries[0]->nx;
         image_size_height = imgs.entries[0]->ny;
     }
@@ -2819,9 +2830,20 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima

     {
         struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw");
-        float * data = (float *)malloc(ggml_nbytes(inp_raw));
+        std::vector<float> inp_data(ggml_nelements(inp_raw));
+        float * data = inp_data.data();
 
+        // layout of data (note: the channel dim is unrolled to better visualize the layout):
+        //
+        // ┌──W──┐
+        // │ H  │ channel = R
+        // ├────┤ │
+        // │ H  │ channel = G
+        // ├────┤ │
+        // │ H  │ channel = B
+        // └────┘ │
+        //  ──────┘ x B
+
+        // TODO @ngxson : this whole code block is ugly, will need to be refactored
         for (size_t i = 0; i < imgs.entries.size(); i++) {
             const int nx = imgs.entries[i]->nx;
             const int ny = imgs.entries[i]->ny;
@@ -2836,17 +2858,19 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
             const int n = nx * ny;
 
             for (int b = 0; b < batch_size; b++) {
-                for (int k = 0; k < 3; k++) {
-                    for (int y = 0; y < ny; y++) {
-                        for (int x = 0; x < nx; x++) {
-                            data[(b * 3 * n) + k * n + y * nx + x] = imgs.entries[b]->buf[3 * (y * nx + x) + k];
-                        }
+                float * batch_entry = data + b * (3*n);
+                for (int y = 0; y < ny; y++) {
+                    for (int x = 0; x < nx; x++) {
+                        size_t base_src = 3*(y * nx + x); // idx of the first channel (interleaved source)
+                        size_t base_dst = y * nx + x;     // idx within one channel plane (planar destination)
+                        batch_entry[      base_dst] = imgs.entries[b]->buf[base_src    ];
+                        batch_entry[1*n + base_dst] = imgs.entries[b]->buf[base_src + 1];
+                        batch_entry[2*n + base_dst] = imgs.entries[b]->buf[base_src + 2];
                     }
                 }
             }
         }
         ggml_backend_tensor_set(inp_raw, data, 0, ggml_nbytes(inp_raw));
-        free(data);
     }
     if (ctx->has_minicpmv_projector) {
         {
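The rewritten copy loop above converts each batch entry from interleaved RGB (RGBRGB...) to the planar layout (all R, then all G, then all B) pictured in the comment. A self-contained toy version of the same transform (for illustration only, not code from the PR):

#include <cstdio>
#include <vector>

int main() {
    const int nx = 2, ny = 2, n = nx * ny;
    std::vector<float> buf(3*n), data(3*n);
    // pixel p, channel k stores 10*p + k, so the planes are easy to spot
    for (int p = 0; p < n; p++)
        for (int k = 0; k < 3; k++)
            buf[3*p + k] = 10.0f*p + k;

    for (int y = 0; y < ny; y++) {
        for (int x = 0; x < nx; x++) {
            const size_t base_src = 3*(y*nx + x); // interleaved: channels adjacent
            const size_t base_dst = y*nx + x;     // planar: channels n elements apart
            data[      base_dst] = buf[base_src    ]; // R plane
            data[1*n + base_dst] = buf[base_src + 1]; // G plane
            data[2*n + base_dst] = buf[base_src + 2]; // B plane
        }
    }
    for (int i = 0; i < 3*n; i++) std::printf("%g ", data[i]);
    std::printf("\n"); // prints: 0 10 20 30 1 11 21 31 2 12 22 32
    return 0;
}
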
2 changes: 2 additions & 0 deletions tests/test-backend-ops.cpp
@@ -2606,6 +2606,8 @@ struct test_rope : public test_case {
             } else {
                 out = ggml_rope_ext_back(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
             }
+
+            // TODO: add test with a non-contiguous view as input; this case is needed for build_rope_2d in clip.cpp
         }
         ggml_set_name(out, "out");

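A sketch of the test case the TODO and the review comment ask for — feed ggml_rope_ext a non-contiguous view and compare against the same op applied to a ggml_cont copy. check_rope_noncont is a made-up name and this is not the test-backend-ops harness; only the ggml calls themselves are existing API:

#include "ggml.h"

static void check_rope_noncont(ggml_context * ctx) {
    const int64_t n_dim = 64, n_head = 8, n_pos = 16;

    ggml_tensor * cur = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_dim, n_head, n_pos);
    ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_pos);

    // second half of each row: a non-contiguous view, as in build_rope_2d
    ggml_tensor * view = ggml_view_3d(ctx, cur,
        n_dim/2, n_head, n_pos,
        ggml_row_size(cur->type, n_dim),
        ggml_row_size(cur->type, n_dim*n_head),
        n_dim/2 * ggml_element_size(cur));

    // reference path: make the view contiguous before rope (mode 0, as build_rope_2d uses)
    ggml_tensor * ref = ggml_rope_ext(ctx, ggml_cont(ctx, view), pos, nullptr,
        n_dim/2, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f);

    // path under test: rope directly on the non-contiguous view
    ggml_tensor * out = ggml_rope_ext(ctx, view, pos, nullptr,
        n_dim/2, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f);

    // a real test would build both graphs, run them on each backend,
    // and compare out against ref element-wise
    (void) ref; (void) out;
}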