
mtmd : add methods to access mtmd_image_tokens #12906

Merged (6 commits) on Apr 18, 2025
11 changes: 6 additions & 5 deletions examples/llava/gemma3-cli.cpp
@@ -184,18 +184,19 @@ static int eval_message(gemma3_context & ctx, common_chat_msg & msg, std::vector
text.text = formatted_chat.prompt;
text.add_special = add_bos;
text.parse_special = true;
mtmd_input_chunks_ptr chunks(mtmd_tokenize(ctx.ctx_vision.get(), text, bitmaps));
if (chunks == nullptr) {
LOG_ERR("Unable to tokenize prompt\n");
mtmd_input_chunks chunks;
int32_t res = mtmd_tokenize(ctx.ctx_vision.get(), chunks, text, bitmaps);
if (res != 0) {
LOG_ERR("Unable to tokenize prompt, res = %d\n", res);
return 1;
}

if (mtmd_helper_eval(ctx.ctx_vision.get(), ctx.lctx, chunks.get(), ctx.n_past, 0, ctx.n_batch)) {
if (mtmd_helper_eval(ctx.ctx_vision.get(), ctx.lctx, chunks, ctx.n_past, 0, ctx.n_batch)) {
LOG_ERR("Unable to eval prompt\n");
return 1;
}

ctx.n_past += mtmd_helper_get_n_tokens(chunks.get());
ctx.n_past += mtmd_helper_get_n_tokens(chunks);

return 0;
}
88 changes: 60 additions & 28 deletions examples/llava/mtmd.cpp
@@ -16,6 +16,7 @@ struct mtmd_context {
struct clip_ctx * ctx_clip;
const struct llama_model * text_model;
std::vector<float> image_embd_v; // image embedding vector

bool print_timings;
int n_threads;
std::string image_marker;
@@ -24,7 +25,11 @@ struct mtmd_context {

mtmd_context(const char * mmproj_fname,
const llama_model * text_model,
const mtmd_context_params & ctx_params) : print_timings(ctx_params.print_timings), n_threads(ctx_params.n_threads), image_marker(ctx_params.image_marker) {
const mtmd_context_params & ctx_params) :
print_timings(ctx_params.print_timings),
n_threads (ctx_params.n_threads),
image_marker (ctx_params.image_marker)
{
clip_context_params ctx_clip_params;
ctx_clip_params.use_gpu = ctx_params.use_gpu;
ctx_clip_params.verbosity = ctx_params.verbosity;
@@ -49,6 +54,7 @@ struct mtmd_image_tokens {
uint32_t ny; // number of tokens in y direction
uint32_t n_tokens() const { return nx * ny; }
clip_image_f32_batch batch_f32; // preprocessed image patches
std::string id; // optional user-defined ID, useful for KV cache tracking
};

mtmd_context * mtmd_init_from_file(const char * mmproj_fname,
@@ -88,10 +94,10 @@ static std::vector<llama_token> mtmd_tokenize_text_internal(
return result;
}

mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx,
const mtmd_input_text & text,
const std::vector<mtmd_bitmap> & bitmaps) {
mtmd_input_chunks * output = new mtmd_input_chunks;
int32_t mtmd_tokenize(mtmd_context * ctx,
std::vector<mtmd_input_chunk> & output,
const mtmd_input_text & text,
const std::vector<mtmd_bitmap> & bitmaps) {
auto vocab = llama_model_get_vocab(ctx->text_model);

std::string prompt_modified(text.text);
@@ -105,9 +111,9 @@ mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx,
string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
}

std::vector<std::string> parts = string_split_str(text.text, ctx->image_marker);
output->clear();
output->reserve(parts.size());
std::vector<std::string> parts = string_split_str(prompt_modified, ctx->image_marker);
output.clear();
output.reserve(parts.size());

size_t i_img = 0;

@@ -123,14 +129,14 @@ mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx,
std::move(tokens),
{},
};
output->emplace_back(std::move(chunk));
output.emplace_back(std::move(chunk));

if (&parts.back() != &part) {
@ggerganov (Member) commented on Apr 13, 2025:
Not 100% sure, but I think this logic does not handle the case where the text ends with an image marker:

<some_text><image_marker>

For Gemma3 this will not happen because we wrap the image marker with text on both sides, but maybe for other models it could happen? If it truly cannot happen, then this if should become an assert.

@ngxson (Collaborator, Author) replied on Apr 13, 2025:

The current logic will produce an empty text chunk at the end if the image marker is placed at the end of the input prompt. This is a side effect of string_split_str.

For example, this code:

auto test = string_split_str("123aa456aa", "aa");
for (auto & p : test) printf("'%s'\n", p.c_str());

Will output:

'123'
'456'
''

I think having an empty chunk at the end is acceptable for now, but I should document it better.

If we don't want this empty chunk, the proper fix is to stop using string_split_str and write our own string matching/splitting code.

In reality, this will almost never happen because users always input a prompt with a generation prefix, something like <s>user\nwhat do you see?<image></s><s>assistant\n

@ngxson (Collaborator, Author) replied on Apr 13, 2025:

Sorry I missed one line of code:

auto tokens = mtmd_tokenize_text_internal(...);
if (tokens.empty()) {
    continue;
}

So no empty chunk is added, and the case where an image marker is placed at the end is handled correctly. This also covers the case where two image markers are placed next to each other.

(This piece of code was first introduced in my initial attempt to refactor the vision API, so yeah, it's quite hacky.)

@ngxson (Collaborator, Author) replied:

I think I will refactor this function later on. Could you review the rest of this PR? Thanks!!

// add image token to middle of 2 parts

if (i_img >= bitmaps.size()) {
LOG_ERR("%s: error: not enough images for %d parts\n", __func__, (int)parts.size());
return nullptr;
return 1;
}

// shim layer
@@ -145,34 +151,48 @@ mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx,
bool ok = clip_image_preprocess(ctx->ctx_clip, img_u8.get(), &batch_f32);
if (!ok) {
LOG_ERR("Unable to preprocess image\n");
return nullptr;
return 2;
}

mtmd_image_tokens * image_tokens = new mtmd_image_tokens;
mtmd_image_tokens_ptr image_tokens(new mtmd_image_tokens);
image_tokens->nx = clip_n_patches(ctx->ctx_clip); // TODO @ngxson : use clip_n_patches_by_image
image_tokens->ny = 1; // TODO
image_tokens->batch_f32 = std::move(batch_f32);
image_tokens->id = bitmaps[i_img].id; // optional

mtmd_input_chunk chunk{
MTMD_INPUT_CHUNK_TYPE_IMAGE,
{},
image_tokens,
std::move(image_tokens),
};
output->emplace_back(std::move(chunk));
output.emplace_back(std::move(chunk));
i_img++;
}
}

return output;
return 0;
}

void mtmd_input_chunks_free(mtmd_input_chunks * chunks) {
for (auto & chunk : *chunks) {
if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE && chunk.tokens_image) {
delete chunk.tokens_image;
}
void mtmd_image_tokens_free(mtmd_image_tokens * image_tokens) {
if (image_tokens) {
delete image_tokens;
}
delete chunks;
}

size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens) {
return image_tokens->n_tokens();
}

size_t mtmd_image_tokens_get_nx(const mtmd_image_tokens * image_tokens) {
return image_tokens->nx;
}

size_t mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens) {
return image_tokens->ny;
}

std::string mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens) {
return image_tokens->id;
}

int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens) {
@@ -190,9 +210,9 @@ float * mtmd_get_output_embd(mtmd_context * ctx) {
return ctx->image_embd_v.data();
}

size_t mtmd_helper_get_n_tokens(mtmd_input_chunks * chunks) {
size_t mtmd_helper_get_n_tokens(mtmd_input_chunks & chunks) {
size_t n_tokens = 0;
for (auto & chunk : *chunks) {
for (auto & chunk : chunks) {
if (chunk.type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
n_tokens += chunk.tokens_text.size();
} else if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
@@ -241,16 +261,16 @@ struct decode_embd_batch {

int32_t mtmd_helper_eval(mtmd_context * ctx,
llama_context * lctx,
mtmd_input_chunks * chunks,
mtmd_input_chunks & chunks,
llama_pos pos0,
llama_seq_id seq_id,
int32_t n_batch) {
int32_t ret;
llama_pos n_past = pos0;
llama_batch text_batch = llama_batch_init(n_batch, 0, 1);

for (auto & chunk : *chunks) {
bool is_last = &chunk == &chunks->back();
for (auto & chunk : chunks) {
bool is_last = &chunk == &chunks.back();
if (chunk.type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
// TODO @ngxson : may need to split into smaller batches
text_batch.n_tokens = chunk.tokens_text.size();
@@ -279,7 +299,7 @@ int32_t mtmd_helper_eval(mtmd_context * ctx,
if (ctx->print_timings) {
LOG_INF("encoding image...\n");
}
ret = mtmd_encode(ctx, chunk.tokens_image);
ret = mtmd_encode(ctx, chunk.tokens_image.get());
if (ret != 0) {
LOG_ERR("failed to encode image\n");
llama_batch_free(text_batch);
Expand All @@ -289,7 +309,7 @@ int32_t mtmd_helper_eval(mtmd_context * ctx,
LOG_INF("image encoded in %" PRId64 " ms\n", ggml_time_ms() - t0);
}

int32_t n_tokens = chunk.tokens_image->n_tokens();
int32_t n_tokens = mtmd_image_tokens_get_n_tokens(chunk.tokens_image.get());
float * embd = mtmd_get_output_embd(ctx);
decode_embd_batch batch_img(embd, n_tokens, n_past, 0);
int64_t t1 = ggml_time_ms();
@@ -339,3 +359,15 @@ int32_t mtmd_helper_bitmap_init_from_file(const char * fname, mtmd_bitmap & outp
std::memcpy(output.data.data(), data, output.nx * output.ny * 3);
return 0;
}

bool mtmd_decode_use_non_causal(mtmd_context * ctx) {
projector_type proj_type = clip_get_projector_type(ctx->ctx_clip);
if (proj_type == PROJECTOR_TYPE_GEMMA3) {
return true;
}
return false;
}

void mtmd_image_tokens_deleter::operator()(mtmd_image_tokens * val) {
mtmd_image_tokens_free(val);
}
37 changes: 26 additions & 11 deletions examples/llava/mtmd.h
@@ -39,12 +39,18 @@ struct mtmd_bitmap {
uint32_t nx;
uint32_t ny;
std::vector<unsigned char> data;
std::string id; // optional user-defined id, for ex: can be set to image hash, useful for KV cache tracking
};

struct mtmd_image_tokens_deleter {
void operator()(mtmd_image_tokens * val); // forward declaration
};
using mtmd_image_tokens_ptr = std::unique_ptr<mtmd_image_tokens, mtmd_image_tokens_deleter>;

struct mtmd_input_chunk {
mtmd_input_chunk_type type;
std::vector<llama_token> tokens_text;
mtmd_image_tokens * tokens_image = nullptr;
mtmd_image_tokens_ptr tokens_image;
};

using mtmd_input_chunks = std::vector<mtmd_input_chunk>;
@@ -82,12 +88,21 @@ MTMD_API void mtmd_free(mtmd_context * ctx);
// 3. "<end_of_image>\ndescribe it in detail."
// number of bitmaps must be equal to the number of image markers in the prompt
// this function is thread-safe (shared ctx)
MTMD_API mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx,
// return values:
// 0 on success
// 1 on number of images not matching the number of markers
// 2 on image preprocessing error
MTMD_API int32_t mtmd_tokenize(mtmd_context * ctx,
std::vector<mtmd_input_chunk> & output,
const mtmd_input_text & text,
const std::vector<mtmd_bitmap> & bitmaps);

// free image chunk data
MTMD_API void mtmd_input_chunks_free(mtmd_input_chunks * chunks);
// access mtmd_image_tokens
MTMD_API size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens);
MTMD_API size_t mtmd_image_tokens_get_nx(const mtmd_image_tokens * image_tokens);
MTMD_API size_t mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens);
MTMD_API std::string mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens);
MTMD_API void mtmd_image_tokens_free(mtmd_image_tokens * image_tokens);

// returns 0 on success
MTMD_API int32_t mtmd_encode(mtmd_context * ctx,
@@ -96,12 +111,17 @@ MTMD_API int32_t mtmd_encode(mtmd_context * ctx,
// get output embeddings from the last encode pass
MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx);

// whether we need to set non-causal mask before llama_decode
MTMD_API bool mtmd_decode_use_non_causal(mtmd_context * ctx);



//
// helper functions (can be implemented based on other functions)
//

// helper to count the total number of tokens from a list of chunks, useful to keep track of n_past
MTMD_API size_t mtmd_helper_get_n_tokens(mtmd_input_chunks * chunks);
MTMD_API size_t mtmd_helper_get_n_tokens(mtmd_input_chunks & chunks);

// helper function that automatically:
// 1. run llama_decode() on text chunks
@@ -110,7 +130,7 @@ MTMD_API size_t mtmd_helper_get_n_tokens(mtmd_input_chunks * chunks);
// otherwise, returns 0 on success
MTMD_API int32_t mtmd_helper_eval(mtmd_context * ctx,
llama_context * lctx,
mtmd_input_chunks * chunks,
mtmd_input_chunks & chunks,
llama_pos pos0,
llama_seq_id seq_id,
int32_t n_batch);
@@ -132,11 +152,6 @@ struct mtmd_context_deleter {
};
using mtmd_context_ptr = std::unique_ptr<mtmd_context, mtmd_context_deleter>;

struct mtmd_input_chunks_deleter {
void operator()(mtmd_input_chunks * val) { mtmd_input_chunks_free(val); }
};
using mtmd_input_chunks_ptr = std::unique_ptr<mtmd_input_chunks, mtmd_input_chunks_deleter>;

#else

static_assert(false && "C header is not yet supported by this library");