diff --git a/examples/llava/clip-impl.h b/examples/llava/clip-impl.h index 4c03529874924..4d7340a56bd0c 100644 --- a/examples/llava/clip-impl.h +++ b/examples/llava/clip-impl.h @@ -1,5 +1,6 @@ #include "ggml.h" #include "gguf.h" +#include "clip.h" #include "clip.h" @@ -202,23 +203,31 @@ static void clip_log_internal(enum ggml_log_level level, const char * format, .. // cpp wrappers // +// wrapper for clip_image_size +struct clip_image_size_deleter { + void operator()(clip_image_size * val) { clip_image_size_free(val); } +}; +typedef std::unique_ptr clip_image_size_ptr; + +// wrapper for clip_image_u8 struct clip_image_u8_deleter { void operator()(clip_image_u8 * val) { clip_image_u8_free(val); } }; +typedef std::unique_ptr clip_image_u8_ptr; +// wrapper for clip_image_f32 struct clip_image_f32_deleter { void operator()(clip_image_f32 * val) { clip_image_f32_free(val); } }; +typedef std::unique_ptr clip_image_f32_ptr; -struct clip_image_f32_batch_deleter { - void operator()(clip_image_f32_batch * val) { clip_image_f32_batch_free(val); } +struct clip_image_u8_batch { + std::vector entries; }; -typedef std::unique_ptr clip_image_u8_ptr; -typedef std::unique_ptr clip_image_f32_ptr; -typedef std::unique_ptr clip_image_f32_batch_ptr; - -// TODO @ngxson : we're currently having a naming clash between struct clip_image_size and function clip_image_size() +struct clip_image_f32_batch { + std::vector entries; +}; // // common utils diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index 710309edaecd6..a55b3f3835184 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -315,58 +315,47 @@ struct clip_ctx { bool use_gelu = false; bool use_silu = false; - struct gguf_context * ctx_gguf = nullptr; - struct ggml_context * ctx_data = nullptr; + gguf_context_ptr ctx_gguf; + ggml_context_ptr ctx_data; std::vector buf_compute_meta; std::vector backend_ptrs; std::vector backend_buft; - ggml_backend_t backend = nullptr; - ggml_backend_t backend_cpu = nullptr; - ggml_backend_buffer_t buf = nullptr; + ggml_backend_ptr backend; + ggml_backend_ptr backend_cpu; + ggml_backend_buffer_ptr buf; ggml_backend_sched_ptr sched; - struct clip_image_size * load_image_size = nullptr; + clip_image_size load_image_size; clip_ctx(clip_context_params & ctx_params) { - backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr); - backend = ctx_params.use_gpu + backend_cpu = ggml_backend_ptr(ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr)); + backend = ggml_backend_ptr(ctx_params.use_gpu ? ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_GPU, nullptr) - : nullptr; + : nullptr); if (backend) { - LOG_INF("%s: CLIP using %s backend\n", __func__, ggml_backend_name(backend)); - backend_ptrs.push_back(backend); - backend_buft.push_back(ggml_backend_get_default_buffer_type(backend)); + LOG_INF("%s: CLIP using %s backend\n", __func__, ggml_backend_name(backend.get())); + backend_ptrs.push_back(backend.get()); + backend_buft.push_back(ggml_backend_get_default_buffer_type(backend.get())); } else { - backend = backend_cpu; + backend = std::move(backend_cpu); LOG_INF("%s: CLIP using CPU backend\n", __func__); } - backend_ptrs.push_back(backend_cpu); - backend_buft.push_back(ggml_backend_get_default_buffer_type(backend_cpu)); + backend_ptrs.push_back(backend_cpu.get()); + backend_buft.push_back(ggml_backend_get_default_buffer_type(backend_cpu.get())); sched.reset( ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), 8192, false) ); } - - ~clip_ctx() { - ggml_free(ctx_data); - gguf_free(ctx_gguf); - ggml_backend_buffer_free(buf); - ggml_backend_free(backend); - if (backend_cpu != backend) { - ggml_backend_free(backend_cpu); - } - clip_image_size_free(load_image_size); - } }; -static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_image_f32_batch * imgs) { +static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_image_f32_batch & imgs) { const auto & model = ctx->vision_model; const auto & hparams = model.hparams; @@ -382,7 +371,7 @@ static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_im const int n_layer = hparams.n_layer; const float eps = hparams.eps; - GGML_ASSERT(imgs->size == 1); // batch_size == 1 + GGML_ASSERT(imgs.entries.size() == 1); // batch_size == 1 struct ggml_init_params params = { /*.mem_size =*/ ctx->buf_compute_meta.size(), @@ -390,7 +379,9 @@ static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_im /*.no_alloc =*/ true, }; - struct ggml_context * ctx0 = ggml_init(params); + ggml_context_ptr ctx0_ptr(ggml_init(params)); + auto ctx0 = ctx0_ptr.get(); + struct ggml_cgraph * gf = ggml_new_graph(ctx0); // input raw @@ -512,12 +503,10 @@ static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_im // build the graph ggml_build_forward_expand(gf, embeddings); - ggml_free(ctx0); - return gf; } -static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_image_f32_batch * imgs, struct clip_image_size * load_image_size, bool is_inf = false) { +static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_image_f32_batch & imgs, struct clip_image_size load_image_size, bool is_inf = false) { if (!ctx->has_vision_encoder) { LOG_ERR("This gguf file seems to have no vision encoder\n"); return nullptr; @@ -530,23 +519,20 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im int image_size_width = image_size; int image_size_height = image_size; if (ctx->has_minicpmv_projector) { - if (load_image_size == nullptr) { - load_image_size = clip_image_size_init(); - } - LOG_DBG("%s: %d %d\n", __func__, load_image_size->width, load_image_size->height); - image_size_width = load_image_size->width; - image_size_height = load_image_size->height; + LOG_DBG("%s: %d %d\n", __func__, load_image_size.width, load_image_size.height); + image_size_width = load_image_size.width; + image_size_height = load_image_size.height; if (is_inf) { - image_size_width = imgs->data->nx; - image_size_height = imgs->data->ny; + image_size_width = imgs.entries[0]->nx; + image_size_height = imgs.entries[0]->ny; } } else if (ctx->has_qwen2vl_merger) { // use the image's native resolution when image is avaible if (is_inf) { // if (imgs->data->nx && imgs->data->ny) { - image_size_width = imgs->data->nx; - image_size_height = imgs->data->ny; + image_size_width = imgs.entries[0]->nx; + image_size_height = imgs.entries[0]->ny; } } const int patch_size = hparams.patch_size; @@ -561,7 +547,7 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im const float eps = hparams.eps; int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4}; - const int batch_size = imgs->size; + const int batch_size = imgs.entries.size(); if (ctx->has_llava_projector || ctx->has_minicpmv_projector || ctx->has_glm_projector) { GGML_ASSERT(batch_size == 1); @@ -573,7 +559,9 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im /*.no_alloc =*/ true, }; - struct ggml_context * ctx0 = ggml_init(params); + ggml_context_ptr ctx0_ptr(ggml_init(params)); + auto ctx0 = ctx0_ptr.get(); + struct ggml_cgraph * gf = ggml_new_graph(ctx0); struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size_width, image_size_height, 3, batch_size); @@ -1061,7 +1049,7 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, embeddings); } } else { - GGML_ABORT("fatel error"); + GGML_ABORT("fatal error"); } } else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) { @@ -1081,12 +1069,10 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im // build the graph ggml_build_forward_expand(gf, embeddings); - ggml_free(ctx0); - return gf; } -static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch * imgs, struct clip_image_size * load_image_size, bool is_inf = false) { +static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch & imgs, struct clip_image_size load_image_size, bool is_inf = false) { if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) { return clip_image_build_graph_siglip(ctx, imgs); } else { @@ -1257,7 +1243,7 @@ struct clip_model_loader { /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; - ctx_clip.ctx_data = ggml_init(params); + ctx_clip.ctx_data.reset(ggml_init(params)); if (!ctx_clip.ctx_data) { throw std::runtime_error(string_format("%s: failed to init ggml context\n", __func__)); } @@ -1271,7 +1257,7 @@ struct clip_model_loader { if (cur) { tensors_to_load.push_back(cur); // add tensors to context - struct ggml_tensor * data_tensor = ggml_dup_tensor(ctx_clip.ctx_data, cur); + struct ggml_tensor * data_tensor = ggml_dup_tensor(ctx_clip.ctx_data.get(), cur); ggml_set_name(data_tensor, cur->name); cur = data_tensor; } @@ -1442,11 +1428,11 @@ struct clip_model_loader { } // alloc memory and offload data - ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(ctx_clip.backend); - ctx_clip.buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx_clip.ctx_data, buft); - ggml_backend_buffer_set_usage(ctx_clip.buf, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); + ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(ctx_clip.backend.get()); + ctx_clip.buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(ctx_clip.ctx_data.get(), buft)); + ggml_backend_buffer_set_usage(ctx_clip.buf.get(), GGML_BACKEND_BUFFER_USAGE_WEIGHTS); for (auto & t : tensors_to_load) { - struct ggml_tensor * cur = ggml_get_tensor(ctx_clip.ctx_data, t->name); + struct ggml_tensor * cur = ggml_get_tensor(ctx_clip.ctx_data.get(), t->name); const size_t offset = tensor_offset[t->name]; fin.seekg(offset, std::ios::beg); if (!fin) { @@ -1471,10 +1457,20 @@ struct clip_model_loader { void alloc_compute_meta() { ctx_clip.buf_compute_meta.resize(GGML_DEFAULT_GRAPH_SIZE * ggml_tensor_overhead() + ggml_graph_overhead()); + + // create a fake batch clip_image_f32_batch batch; - batch.size = 1; - batch.data = nullptr; - ggml_cgraph * gf = clip_image_build_graph(&ctx_clip, &batch, nullptr, false); + clip_image_f32_ptr img(clip_image_f32_init()); + clip_image_size image_size; + image_size.width = clip_get_image_size(&ctx_clip); + image_size.height = clip_get_image_size(&ctx_clip); + int n_patches = clip_get_image_size(&ctx_clip) / image_size.width; + img->nx = n_patches; + img->ny = n_patches; + img->buf.resize(n_patches * image_size.width * image_size.height * 3); + batch.entries.push_back(std::move(img)); + + ggml_cgraph * gf = clip_image_build_graph(&ctx_clip, batch, image_size, false); ggml_backend_sched_reserve(ctx_clip.sched.get(), gf); for (size_t i = 0; i < ctx_clip.backend_ptrs.size(); ++i) { ggml_backend_t backend = ctx_clip.backend_ptrs[i]; @@ -1575,11 +1571,11 @@ struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_p } void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size) { - ctx_clip->load_image_size = load_image_size; + ctx_clip->load_image_size = *load_image_size; // copy } struct clip_image_size * clip_get_load_image_size(struct clip_ctx * ctx_clip) { - return ctx_clip->load_image_size; + return &ctx_clip->load_image_size; } struct clip_image_size * clip_image_size_init() { @@ -1597,6 +1593,10 @@ struct clip_image_f32 * clip_image_f32_init() { return new clip_image_f32(); } +struct clip_image_f32_batch * clip_image_f32_batch_init() { + return new clip_image_f32_batch(); +} + unsigned char * clip_image_u8_get_data(struct clip_image_u8 * img, uint32_t * nx, uint32_t * ny) { if (nx) *nx = img->nx; if (ny) *ny = img->ny; @@ -1609,19 +1609,37 @@ void clip_image_size_free(struct clip_image_size * load_image_size) { } delete load_image_size; } -void clip_image_u8_free(struct clip_image_u8 * img) { delete img; } -void clip_image_f32_free(struct clip_image_f32 * img) { delete img; } -void clip_image_u8_batch_free(struct clip_image_u8_batch * batch) { - if (batch->size > 0) { - delete[] batch->data; - batch->size = 0; +void clip_image_u8_free(struct clip_image_u8 * img) { if (img) delete img; } +void clip_image_f32_free(struct clip_image_f32 * img) { if (img) delete img; } +void clip_image_u8_batch_free(struct clip_image_u8_batch * batch) { if (batch) delete batch; } +void clip_image_f32_batch_free(struct clip_image_f32_batch * batch) { if (batch) delete batch; } + +size_t clip_image_f32_batch_n_images(const struct clip_image_f32_batch * batch) { + return batch->entries.size(); +} + +size_t clip_image_f32_batch_nx(const struct clip_image_f32_batch * batch, int idx) { + if (idx < 0 || idx >= (int)batch->entries.size()) { + LOG_ERR("%s: invalid index %d\n", __func__, idx); + return 0; + } + return batch->entries[idx]->nx; +} + +size_t clip_image_f32_batch_ny(const struct clip_image_f32_batch * batch, int idx) { + if (idx < 0 || idx >= (int)batch->entries.size()) { + LOG_ERR("%s: invalid index %d\n", __func__, idx); + return 0; } + return batch->entries[idx]->ny; } -void clip_image_f32_batch_free(struct clip_image_f32_batch * batch) { - if (batch->size > 0) { - delete[] batch->data; - batch->size = 0; + +clip_image_f32 * clip_image_f32_get_img(const struct clip_image_f32_batch * batch, int idx) { + if (idx < 0 || idx >= (int)batch->entries.size()) { + LOG_ERR("%s: invalid index %d\n", __func__, idx); + return nullptr; } + return batch->entries[idx].get(); } void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny, clip_image_u8 * img) { @@ -1695,14 +1713,15 @@ static void bilinear_resize(const clip_image_u8& src, clip_image_u8& dst, int ta } // Normalize image to float32 - careful with pytorch .to(model.device, dtype=torch.float16) - this sometimes reduces precision (32>16>32), sometimes not -static void normalize_image_u8_to_f32(const clip_image_u8* src, clip_image_f32* dst, const float mean[3], const float std[3]) { - dst->nx = src->nx; - dst->ny = src->ny; - dst->buf.resize(src->buf.size()); +static void normalize_image_u8_to_f32(const clip_image_u8 & src, clip_image_f32 & dst, const float mean[3], const float std[3]) { + dst.nx = src.nx; + dst.ny = src.ny; + dst.buf.resize(src.buf.size()); - for (size_t i = 0; i < src->buf.size(); ++i) { + // TODO @ngxson : seems like this could be done more efficiently on cgraph + for (size_t i = 0; i < src.buf.size(); ++i) { int c = i % 3; // rgb - dst->buf[i] = (static_cast(src->buf[i]) / 255.0f - mean[c]) / std[c]; + dst.buf[i] = (static_cast(src.buf[i]) / 255.0f - mean[c]) / std[c]; } } @@ -1710,7 +1729,7 @@ inline int clip(int x, int lower, int upper) { return std::max(lower, std::min(x, upper)); } -static bool bicubic_resize(const clip_image_u8 &img, clip_image_u8 &dst, int target_width, int target_height) { +static bool bicubic_resize(const clip_image_u8 & img, clip_image_u8 & dst, int target_width, int target_height) { const int nx = img.nx; const int ny = img.ny; @@ -1848,13 +1867,13 @@ static std::pair select_best_resolution(const std::pair & or return best_fit; } -static std::vector divide_to_patches_u8(const clip_image_u8 & image, int patch_size) { - std::vector patches; +static std::vector divide_to_patches_u8(const clip_image_u8 & image, int patch_size) { + std::vector patches; int width = image.nx; int height = image.ny; for (int i = 0; i < height; i += patch_size) { for (int j = 0; j < width; j += patch_size) { - clip_image_u8 *patch = clip_image_u8_init(); + clip_image_u8_ptr patch(clip_image_u8_init()); patch->nx = std::min(patch_size, width - j); patch->ny = std::min(patch_size, height - i); patch->buf.resize(3 * patch->nx * patch->ny); @@ -1865,7 +1884,7 @@ static std::vector divide_to_patches_u8(const clip_image_u8 & im } } } - patches.push_back(patch); + patches.push_back(std::move(patch)); } } return patches; @@ -1946,7 +1965,7 @@ static std::pair uhd_best_grid(const int max_slice_nums, const int mul // -> https://arxiv.org/pdf/2403.11703 // -> https://github.com/thunlp/LLaVA-UHD // -> https://github.com/thunlp/LLaVA-UHD/blob/302301bc2175f7e717fb8548516188e89f649753/llava_uhd/train/llava-uhd/slice_logic.py#L118 -static std::vector> uhd_slice_image(const clip_image_u8 * img, const int max_slice_nums=9, const int scale_resolution=448, const int patch_size=14) { +static std::vector> uhd_slice_image(const clip_image_u8 * img, const int max_slice_nums=9, const int scale_resolution=448, const int patch_size=14) { const std::pair original_size={img->nx,img->ny}; const int original_width = img->nx; const int original_height = img->ny; @@ -1954,30 +1973,30 @@ static std::vector> uhd_slice_image(const clip_imag const float ratio = 1.0 * original_width * original_height/ (scale_resolution * scale_resolution); const int multiple = fmin(ceil(ratio), max_slice_nums); - std::vector> images; + std::vector> images; LOG_DBG("%s: multiple %d\n", __func__, multiple); - images.push_back(std::vector()); + images.push_back(std::vector()); if (multiple <= 1) { auto best_size = uhd_find_best_resize(original_size, scale_resolution, patch_size, true); - clip_image_u8 * source_image = clip_image_u8_init(); + clip_image_u8_ptr source_image(clip_image_u8_init()); bicubic_resize(*img, *source_image, best_size.first, best_size.second); // source_image = image.resize(best_size, Image.Resampling.BICUBIC) - images[images.size()-1].push_back(source_image); + images.back().push_back(std::move(source_image)); } else if (multiple > 1) { auto best_size = uhd_find_best_resize(original_size, scale_resolution, patch_size); - clip_image_u8 * source_image = clip_image_u8_init(); + clip_image_u8_ptr source_image(clip_image_u8_init()); bicubic_resize(*img, *source_image, best_size.first, best_size.second); // source_image = image.copy().resize(best_resize, Image.Resampling.BICUBIC) LOG_DBG("%s: image_size: %d %d; source_image size: %d %d\n", __func__, img->nx, img->ny, best_size.first, best_size.second); - images[images.size()-1].push_back(source_image); + images.back().push_back(std::move(source_image)); std::pair best_grid = uhd_best_grid(max_slice_nums, multiple, log_ratio); LOG_DBG("%s: image_size: %d %d; best_grid: %d %d\n", __func__, img->nx, img->ny, best_grid.first, best_grid.second); auto refine_size = uhd_get_refine_size(original_size, best_grid, scale_resolution, patch_size, true); - clip_image_u8 * refine_image = clip_image_u8_init(); + clip_image_u8_ptr refine_image(clip_image_u8_init()); bicubic_resize(*img, *refine_image, refine_size.first, refine_size.second); LOG_DBG("%s: refine_image_size: %d %d; refine_size: %d %d\n", __func__, refine_image->nx, refine_image->ny, refine_size.first, refine_size.second); @@ -1988,9 +2007,9 @@ static std::vector> uhd_slice_image(const clip_imag int grid_x = int(width / best_grid.first); int grid_y = int(height / best_grid.second); for (int patches_i = 0, ic = 0; patches_i < height && ic < best_grid.second; patches_i += grid_y, ic += 1){ - images.push_back(std::vector()); + images.push_back(std::vector()); for(int patches_j = 0, jc = 0; patches_j < width && jc < best_grid.first; patches_j += grid_x, jc += 1){ - clip_image_u8 * patch = clip_image_u8_init(); + clip_image_u8_ptr patch(clip_image_u8_init()); patch->nx = grid_x; patch->ny = grid_y; patch->buf.resize(3 * patch->nx * patch->ny); @@ -2003,10 +2022,9 @@ static std::vector> uhd_slice_image(const clip_imag patch->buf[j+2] = refine_image->buf[i+2]; } } - images[images.size()-1].push_back(patch); + images.back().push_back(std::move(patch)); } } - clip_image_u8_free(refine_image); } return images; } @@ -2014,8 +2032,8 @@ static std::vector> uhd_slice_image(const clip_imag int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip) { const int max_slice_nums=9; const int scale_resolution=448; - const int original_width = ctx_clip->load_image_size->width; - const int original_height = ctx_clip->load_image_size->height; + const int original_width = ctx_clip->load_image_size.width; + const int original_height = ctx_clip->load_image_size.height; const float log_ratio = log(1.0*original_width/original_height); const float ratio = 1.0 * original_width * original_height/ (scale_resolution * scale_resolution); const int multiple = fmin(ceil(ratio), max_slice_nums); @@ -2025,64 +2043,44 @@ int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip) { // returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patch tensors as a vector // res_imgs memory is being allocated here, previous allocations will be freed if found -bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, clip_image_f32_batch * res_imgs) { +bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, struct clip_image_f32_batch * res_imgs) { - if(clip_is_minicpmv(ctx)){ + if (clip_is_minicpmv(ctx)) { int max_slice_nums = 9; - std::vector> imgs = uhd_slice_image(img, max_slice_nums); - res_imgs->size = 0; - for (size_t i = 0; i < imgs.size(); ++i){ - res_imgs->size += imgs[i].size(); - } - res_imgs->data = new clip_image_f32[res_imgs->size]; - int idx = 0; + std::vector> imgs = uhd_slice_image(img, max_slice_nums); for (size_t i = 0; i < imgs.size(); ++i) { for (size_t j = 0; j < imgs[i].size(); ++j) { LOG_DBG("%s: %d %d\n", __func__,imgs[i][j]->nx,imgs[i][j]->ny); - clip_image_f32 * res = clip_image_f32_init(); - normalize_image_u8_to_f32(imgs[i][j], res, ctx->image_mean, ctx->image_std); - res_imgs->data[idx++] = *res; - clip_image_f32_free(res); - } - } - for (size_t i = 0; i < imgs.size(); ++i) { - for (size_t j = 0; j < imgs[i].size(); ++j) { - if (imgs[i][j] != nullptr) { - clip_image_u8_free(imgs[i][j]); - } + clip_image_f32_ptr res(clip_image_f32_init()); + normalize_image_u8_to_f32(*imgs[i][j], *res, ctx->image_mean, ctx->image_std); + res_imgs->entries.push_back(std::move(res)); } } return true; } else if (ctx->has_qwen2vl_merger) { - clip_image_u8 * resized = clip_image_u8_init(); - auto patch_size = clip_patch_size(ctx) * 2; + clip_image_u8 resized; + auto patch_size = clip_get_patch_size(ctx) * 2; int nx = ceil((float)img->nx / patch_size) * patch_size; int ny = ceil((float)img->ny / patch_size) * patch_size; - bicubic_resize(*img, *resized, nx, ny); + bicubic_resize(*img, resized, nx, ny); - res_imgs->data = new clip_image_f32[1]; - // clip_image_f32 * res = clip_image_f32_init(); - normalize_image_u8_to_f32(resized, res_imgs->data, ctx->image_mean, ctx->image_std); + clip_image_f32_ptr img_f32(clip_image_f32_init()); + // clip_image_f32_ptr res(clip_image_f32_init()); + normalize_image_u8_to_f32(resized, *img_f32, ctx->image_mean, ctx->image_std); // res_imgs->data[0] = *res; - res_imgs->size = 1; - - // clip_image_f32_free(res); - clip_image_u8_free(resized); + res_imgs->entries.push_back(std::move(img_f32)); return true; } if (ctx->has_glm_projector || ctx->proj_type == PROJECTOR_TYPE_GEMMA3) { - res_imgs->size = 1; - res_imgs->data = new clip_image_f32[res_imgs->size]; clip_image_u8 resized_image; int32_t sz=ctx->vision_model.hparams.image_size; bicubic_resize(*img, resized_image,sz,sz); - clip_image_f32 * res = clip_image_f32_init(); + clip_image_f32_ptr img_f32(clip_image_f32_init()); //clip_image_save_to_bmp(resized_image, "resized.bmp"); - normalize_image_u8_to_f32(&resized_image, res, ctx->image_mean, ctx->image_std); - res_imgs->data[0] = *res; - clip_image_f32_free(res); + normalize_image_u8_to_f32(resized_image, *img_f32, ctx->image_mean, ctx->image_std); + res_imgs->entries.push_back(std::move(img_f32)); return true; } @@ -2097,16 +2095,12 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli pad_to_square = false; } // free the previous res_imgs if any set - if (res_imgs->size > 0) { - clip_image_f32_batch_free(res_imgs); - } - res_imgs->data = nullptr; - res_imgs->size = 0; + res_imgs->entries.clear(); // the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104) // see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156 - clip_image_u8 * temp = clip_image_u8_init(); // we will keep the input image data here temporarily + clip_image_u8_ptr temp(clip_image_u8_init()); // we will keep the input image data here temporarily if (pad_to_square && img->nx != img->ny) { int longer_side = std::max(img->nx, img->ny); temp->nx = longer_side; @@ -2149,28 +2143,18 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli // clip_image_u8_free(temp2); // } - std::vector patches = divide_to_patches_u8(*temp, params.image_size); // prepare spatial sorted main patches of image_size each (336 in llava-1.6) + std::vector patches = divide_to_patches_u8(*temp, params.image_size); // prepare spatial sorted main patches of image_size each (336 in llava-1.6) - clip_image_u8 *image_original_resize = clip_image_u8_init(); + clip_image_u8_ptr image_original_resize(clip_image_u8_init()); // bilinear_resize(*img, *image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square bicubic_resize(*img, *image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square - patches.insert(patches.begin(), image_original_resize); - // clip_image_f32_batch_init(patches.size()); - res_imgs->size = patches.size(); - res_imgs->data = new clip_image_f32[res_imgs->size]; - int num=0; - for (auto& patch : patches) { - normalize_image_u8_to_f32(patch, &res_imgs->data[num], ctx->image_mean, ctx->image_std); - num++; - } - - for (size_t i = 0; i < patches.size(); i++) { - // LOG_DBG("patch %d: %d %d\n", i, patches[i]->nx, patches[i]->ny); - clip_image_u8_free(patches[i]); + patches.insert(patches.begin(), std::move(image_original_resize)); + for (auto & patch : patches) { + clip_image_f32_ptr res(clip_image_f32_init()); + normalize_image_u8_to_f32(*patch, *res, ctx->image_mean, ctx->image_std); + res_imgs->entries.push_back(std::move(res)); } - clip_image_u8_free(temp); - return true; } else { temp->nx = img->nx; @@ -2186,7 +2170,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli const int nx2 = ctx->vision_model.hparams.image_size; const int ny2 = ctx->vision_model.hparams.image_size; - clip_image_f32 * res = clip_image_f32_init(); + clip_image_f32_ptr res(clip_image_f32_init()); res->nx = nx2; res->ny = ny2; res->buf.resize(3 * nx2 * ny2); @@ -2238,7 +2222,6 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli } } } - clip_image_u8_free(temp); // { // clip_image_u8 * temp2 = clip_image_u8_init(); @@ -2248,10 +2231,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli // } // res_imgs.push_back(res); - res_imgs->size = 1; - res_imgs->data = new clip_image_f32[res_imgs->size]; - res_imgs->data[0] = *res; - clip_image_f32_free(res); + res_imgs->entries.push_back(std::move(res)); return true; } @@ -2279,15 +2259,15 @@ size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_h, int img_w return clip_n_patches_by_img(ctx, &img) * clip_n_mmproj_embd(ctx) * sizeof(float); } -int32_t clip_image_size(const struct clip_ctx * ctx) { +int32_t clip_get_image_size(const struct clip_ctx * ctx) { return ctx->vision_model.hparams.image_size; } -int32_t clip_patch_size(const struct clip_ctx * ctx) { +int32_t clip_get_patch_size(const struct clip_ctx * ctx) { return ctx->vision_model.hparams.patch_size; } -int32_t clip_hidden_size(const struct clip_ctx * ctx) { +int32_t clip_get_hidden_size(const struct clip_ctx * ctx) { return ctx->vision_model.hparams.hidden_size; } @@ -2434,19 +2414,23 @@ bool clip_image_encode(struct clip_ctx * ctx, const int n_threads, clip_image_f3 return false; } - clip_image_f32_batch imgs{}; - imgs.size = 1; - imgs.data = img; + clip_image_f32_batch imgs; + clip_image_f32_ptr img_copy(clip_image_f32_init()); + *img_copy = *img; + imgs.entries.push_back(std::move(img_copy)); + return clip_image_batch_encode(ctx, n_threads, &imgs, vec); } -bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_image_f32_batch * imgs, float * vec) { +bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_image_f32_batch * imgs_c_ptr, float * vec) { + const clip_image_f32_batch & imgs = *imgs_c_ptr; + if (!ctx->has_vision_encoder) { LOG_ERR("%s: This gguf file seems to have no vision encoder\n", __func__); return false; } - int batch_size = imgs->size; + int batch_size = imgs.entries.size(); if (ctx->has_llava_projector) { GGML_ASSERT(batch_size == 1); // TODO: support multiple images } @@ -2473,25 +2457,22 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima int image_size_width = image_size; int image_size_height = image_size; if (ctx->has_minicpmv_projector | ctx->has_qwen2vl_merger) { - image_size_width = imgs->data[0].nx; - image_size_height = imgs->data[0].ny; + image_size_width = imgs.entries[0]->nx; + image_size_height = imgs.entries[0]->ny; } const int patch_size = hparams.patch_size; const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size)); const int num_positions = num_patches + (model.class_embedding ? 1 : 0); - if(ctx->load_image_size==nullptr){ - ctx->load_image_size= clip_image_size_init(); - } - const int pos_w = ctx->load_image_size->width/patch_size; - const int pos_h = ctx->load_image_size->height/patch_size; + const int pos_w = ctx->load_image_size.width / patch_size; + const int pos_h = ctx->load_image_size.height / patch_size; { struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw"); float * data = (float *)malloc(ggml_nbytes(inp_raw)); - for (size_t i = 0; i < imgs->size; i++) { - const int nx = imgs->data[i].nx; - const int ny = imgs->data[i].ny; + for (size_t i = 0; i < imgs.entries.size(); i++) { + const int nx = imgs.entries[i]->nx; + const int ny = imgs.entries[i]->ny; if (!(ctx->has_minicpmv_projector | ctx->has_qwen2vl_merger)) { GGML_ASSERT(nx == image_size && ny == image_size); } @@ -2502,7 +2483,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima for (int k = 0; k < 3; k++) { for (int y = 0; y < ny; y++) { for (int x = 0; x < nx; x++) { - data[(b * 3 * n) + k * n + y * nx + x] = imgs->data[b].buf[3 * (y * nx + x) + k]; + data[(b * 3 * n) + k * n + y * nx + x] = imgs.entries[b]->buf[3 * (y * nx + x) + k]; } } } @@ -2629,7 +2610,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } } - ggml_backend_cpu_set_n_threads(ctx->backend_cpu, n_threads); + ggml_backend_cpu_set_n_threads(ctx->backend_cpu.get(), n_threads); auto status = ggml_backend_sched_graph_compute(ctx->sched.get(), gf); if (status != GGML_STATUS_SUCCESS) { @@ -2662,8 +2643,8 @@ bool clip_model_quantize(const char * fname_inp, const char * fname_out, const i /* verbosity */ GGML_LOG_LEVEL_ERROR, }); - const auto & ctx_src = ctx_clip->ctx_gguf; - const auto & ctx_data = ctx_clip->ctx_data; + const auto & ctx_src = ctx_clip->ctx_gguf.get(); + const auto & ctx_data = ctx_clip->ctx_data.get(); auto * ctx_out = gguf_init_empty(); gguf_set_kv(ctx_out, ctx_src); diff --git a/examples/llava/clip.h b/examples/llava/clip.h index f61e0c0b2b3a7..cc133a58de3e8 100644 --- a/examples/llava/clip.h +++ b/examples/llava/clip.h @@ -30,15 +30,8 @@ struct clip_image_size { int height; }; -struct clip_image_u8_batch { - struct clip_image_u8 * data; - size_t size; -}; - -struct clip_image_f32_batch { - struct clip_image_f32 * data; - size_t size; -}; +struct clip_image_u8_batch; +struct clip_image_f32_batch; struct clip_context_params { bool use_gpu; @@ -55,9 +48,9 @@ CLIP_API void clip_free(struct clip_ctx * ctx); CLIP_API size_t clip_embd_nbytes(const struct clip_ctx * ctx); CLIP_API size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_h, int img_w); -CLIP_API int32_t clip_image_size (const struct clip_ctx * ctx); -CLIP_API int32_t clip_patch_size (const struct clip_ctx * ctx); -CLIP_API int32_t clip_hidden_size(const struct clip_ctx * ctx); +CLIP_API int32_t clip_get_image_size (const struct clip_ctx * ctx); +CLIP_API int32_t clip_get_patch_size (const struct clip_ctx * ctx); +CLIP_API int32_t clip_get_hidden_size(const struct clip_ctx * ctx); // TODO: should be enum, not string CLIP_API const char * clip_patch_merge_type(const struct clip_ctx * ctx); @@ -73,9 +66,10 @@ CLIP_API int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip); CLIP_API void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size); CLIP_API struct clip_image_size * clip_get_load_image_size(struct clip_ctx * ctx_clip); -CLIP_API struct clip_image_size * clip_image_size_init(); -CLIP_API struct clip_image_u8 * clip_image_u8_init (); -CLIP_API struct clip_image_f32 * clip_image_f32_init(); +CLIP_API struct clip_image_size * clip_image_size_init(); +CLIP_API struct clip_image_u8 * clip_image_u8_init (); +CLIP_API struct clip_image_f32 * clip_image_f32_init(); +CLIP_API struct clip_image_f32_batch * clip_image_f32_batch_init(); // only used by libllava // nx, ny are the output image dimensions CLIP_API unsigned char * clip_image_u8_get_data(struct clip_image_u8 * img, uint32_t * nx, uint32_t * ny); @@ -86,6 +80,12 @@ CLIP_API void clip_image_f32_free(struct clip_image_f32 * img); CLIP_API void clip_image_u8_batch_free (struct clip_image_u8_batch * batch); CLIP_API void clip_image_f32_batch_free(struct clip_image_f32_batch * batch); +// use for accessing underlay data of clip_image_f32_batch +CLIP_API size_t clip_image_f32_batch_n_images(const struct clip_image_f32_batch * batch); // equivalent to batch->size() +CLIP_API size_t clip_image_f32_batch_nx(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->nx +CLIP_API size_t clip_image_f32_batch_ny(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->ny +CLIP_API clip_image_f32 * clip_image_f32_get_img(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->data + /** * Build image from pixels decoded by other libraries instead of stb_image.h for better performance. * The memory layout is RGBRGBRGB..., input buffer length must be 3*nx*ny bytes diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp index 518aad3f1f70b..03a22cbb4c205 100644 --- a/examples/llava/llava.cpp +++ b/examples/llava/llava.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #if defined(LLAVA_LOG_OFF) # define LOG_INF(...) @@ -45,6 +46,17 @@ struct clip_image_grid_shape { int second; }; +// convenience cpp wrapper +struct clip_image_f32_batch_deleter { + void operator()(clip_image_f32_batch * val) { clip_image_f32_batch_free(val); } +}; +typedef std::unique_ptr clip_image_f32_batch_ptr; + +struct clip_image_size_deleter { + void operator()(clip_image_f32_batch * val) { clip_image_f32_batch_free(val); } +}; +typedef std::unique_ptr clip_image_size_ptr; + /** * Selects the best resolution from a list of possible resolutions based on the original size. * @@ -105,8 +117,8 @@ static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector struct ggml_context * ctx; } model; - const int32_t image_size = clip_image_size(ctx_clip); - const int32_t patch_size = clip_patch_size(ctx_clip); + const int32_t image_size = clip_get_image_size(ctx_clip); + const int32_t patch_size = clip_get_patch_size(ctx_clip); int32_t num_patches_per_side = image_size / patch_size; // 336 / 14 = 24 - used for embedding-patching boxes (24*24 = 576 patches) @@ -246,12 +258,9 @@ static clip_image_f32 * reshape_by_patch(clip_image_f32 * image, int patch_size) static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float * image_embd, int * n_img_pos) { // std::vector img_res_v; // format VectN x H x W x RGB (N x 336 x 336 x 3), so interleaved RGB - different to the python implementation which is N x 3 x 336 x 336 - clip_image_f32_batch img_res_v; - img_res_v.size = 0; - img_res_v.data = nullptr; - if (!clip_image_preprocess(ctx_clip, img, &img_res_v)) { + clip_image_f32_batch_ptr img_res_v(clip_image_f32_batch_init()); + if (!clip_image_preprocess(ctx_clip, img, img_res_v.get())) { LOG_ERR("%s: unable to preprocess image\n", __func__); - delete[] img_res_v.data; return false; } @@ -259,66 +268,72 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli const char * mm_patch_merge_type = clip_patch_merge_type(ctx_clip); + const size_t n_imgs = clip_image_f32_batch_n_images(img_res_v.get()); + if (clip_is_minicpmv(ctx_clip) || clip_is_qwen2vl(ctx_clip)) { std::vector image_embd_v; - image_embd_v.resize(img_res_v.size); - struct clip_image_size * load_image_size = clip_image_size_init(); + image_embd_v.resize(n_imgs); + clip_image_size load_image_size; - for (size_t i = 0; i < img_res_v.size; i++) { + for (size_t i = 0; i < n_imgs; i++) { const int64_t t_img_enc_step_start_us = ggml_time_us(); - image_embd_v[i] = (float *)malloc(clip_embd_nbytes_by_img(ctx_clip, img_res_v.data[i].nx, img_res_v.data[i].ny)); - int patch_size=14; - load_image_size->width = img_res_v.data[i].nx; - load_image_size->height = img_res_v.data[i].ny; - clip_add_load_image_size(ctx_clip, load_image_size); + int nx = clip_image_f32_batch_nx(img_res_v.get(), i); + int ny = clip_image_f32_batch_ny(img_res_v.get(), i); + image_embd_v[i] = (float *)malloc(clip_embd_nbytes_by_img(ctx_clip, nx, ny)); + int patch_size = 14; + load_image_size.width = nx; + load_image_size.height = ny; + clip_add_load_image_size(ctx_clip, &load_image_size); bool encoded = false; + clip_image_f32 * img_res = clip_image_f32_get_img(img_res_v.get(), i); if (clip_is_qwen2vl(ctx_clip)) { - encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]); + encoded = clip_image_encode(ctx_clip, n_threads, img_res, image_embd_v[i]); } else { - encoded = clip_image_encode(ctx_clip, n_threads, reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]); + encoded = clip_image_encode(ctx_clip, n_threads, reshape_by_patch(img_res, patch_size), image_embd_v[i]); } if (!encoded) { - LOG_ERR("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int) i+1, (int) img_res_v.size); + LOG_ERR("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int) i+1, (int) n_imgs); return false; } const int64_t t_img_enc_steop_batch_us = ggml_time_us(); - LOG_INF("%s: step %d of %d encoded in %8.2f ms\n", __func__, (int)i+1, (int)img_res_v.size, (t_img_enc_steop_batch_us - t_img_enc_step_start_us) / 1000.0); + LOG_INF("%s: step %d of %d encoded in %8.2f ms\n", __func__, (int)i+1, (int)n_imgs, (t_img_enc_steop_batch_us - t_img_enc_step_start_us) / 1000.0); } const int64_t t_img_enc_batch_us = ggml_time_us(); - LOG_INF("%s: all %d segments encoded in %8.2f ms\n", __func__, (int)img_res_v.size, (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0); + LOG_INF("%s: all %d segments encoded in %8.2f ms\n", __func__, (int)n_imgs, (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0); int n_img_pos_out = 0; for (size_t i = 0; i < image_embd_v.size(); i++) { + int nx = clip_image_f32_batch_nx(img_res_v.get(), i); + int ny = clip_image_f32_batch_ny(img_res_v.get(), i); + clip_image_f32 * img_res = clip_image_f32_get_img(img_res_v.get(), i); std::memcpy( image_embd + n_img_pos_out * clip_n_mmproj_embd(ctx_clip), image_embd_v[i], - clip_embd_nbytes_by_img(ctx_clip, img_res_v.data[i].nx, img_res_v.data[i].ny)); - n_img_pos_out += clip_n_patches_by_img(ctx_clip, &img_res_v.data[i]); + clip_embd_nbytes_by_img(ctx_clip, nx, ny)); + n_img_pos_out += clip_n_patches_by_img(ctx_clip, img_res); } *n_img_pos = n_img_pos_out; for (size_t i = 0; i < image_embd_v.size(); i++) { free(image_embd_v[i]); } image_embd_v.clear(); - load_image_size->width = img->nx; - load_image_size->height = img->ny; - clip_add_load_image_size(ctx_clip, load_image_size); - LOG_INF("%s: load_image_size %d %d\n", __func__, load_image_size->width, load_image_size->height); - delete[] img_res_v.data; - img_res_v.size = 0; - img_res_v.data = nullptr; + load_image_size.width = img->nx; + load_image_size.height = img->ny; + clip_add_load_image_size(ctx_clip, &load_image_size); + LOG_INF("%s: load_image_size %d %d\n", __func__, load_image_size.width, load_image_size.height); } else if (clip_is_glm(ctx_clip)){ struct clip_image_size * load_image_size = clip_image_size_init(); - load_image_size->width = img_res_v.data[0].nx; - load_image_size->height = img_res_v.data[0].ny; + load_image_size->width = clip_image_f32_batch_nx(img_res_v.get(), 0); + load_image_size->height = clip_image_f32_batch_ny(img_res_v.get(), 0); clip_add_load_image_size(ctx_clip, load_image_size); - bool encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[0], image_embd); - int pos = int(load_image_size->width/clip_patch_size(ctx_clip)/2); + clip_image_f32 * img_res = clip_image_f32_get_img(img_res_v.get(), 0); + bool encoded = clip_image_encode(ctx_clip, n_threads, img_res, image_embd); + int pos = int(load_image_size->width/clip_get_patch_size(ctx_clip)/2); *n_img_pos = (pos * pos + 2); if (!encoded){ LOG_ERR("Unable to encode image \n"); @@ -328,8 +343,8 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli else if (strcmp(mm_patch_merge_type, "spatial_unpad") != 0) { // flat / default llava-1.5 type embedding *n_img_pos = clip_n_patches(ctx_clip); - bool encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[0], image_embd); // image_embd shape is 576 x 4096 - delete[] img_res_v.data; + clip_image_f32 * img_res = clip_image_f32_get_img(img_res_v.get(), 0); + bool encoded = clip_image_encode(ctx_clip, n_threads, img_res, image_embd); // image_embd shape is 576 x 4096 if (!encoded) { LOG_ERR("Unable to encode image\n"); @@ -340,17 +355,18 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli // spatial_unpad llava-1.6 type embedding // TODO: CLIP needs batching support - in HF the llm projection is separate after encoding, which might be a solution to quickly get batching working std::vector image_embd_v; - image_embd_v.resize(img_res_v.size); - for (size_t i = 0; i < img_res_v.size; i++) { + image_embd_v.resize(n_imgs); + for (size_t i = 0; i < n_imgs; i++) { + clip_image_f32 * img_res = clip_image_f32_get_img(img_res_v.get(), i); image_embd_v[i] = (float *)malloc(clip_embd_nbytes(ctx_clip)); // 576 patches * 4096 embeddings * 4 bytes = 9437184 - const bool encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]); // image data is in 3x336x336 format and will be converted to 336x336x3 inside + const bool encoded = clip_image_encode(ctx_clip, n_threads, img_res, image_embd_v[i]); // image data is in 3x336x336 format and will be converted to 336x336x3 inside if (!encoded) { - LOG_ERR("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int) i+1, (int) img_res_v.size); + LOG_ERR("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int) i+1, (int) n_imgs); return false; } } const int64_t t_img_enc_batch_us = ggml_time_us(); - LOG_INF("%s: %d segments encoded in %8.2f ms\n", __func__, (int)img_res_v.size, (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0); + LOG_INF("%s: %d segments encoded in %8.2f ms\n", __func__, (int)n_imgs, (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0); const int32_t * image_grid = clip_image_grid(ctx_clip); const size_t num_gridpoints = get_clip_image_grid_size(ctx_clip); @@ -360,12 +376,7 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli grid_pinpoints.push_back({image_grid[i], image_grid[i+1]}); } - // free all img_res_v - not needed anymore - delete[] img_res_v.data; - img_res_v.size = 0; - img_res_v.data = nullptr; - - const int32_t image_size = clip_image_size(ctx_clip); + const int32_t image_size = clip_get_image_size(ctx_clip); struct clip_image_grid_shape grid_shape = get_anyres_image_grid_shape({img->nx,img->ny}, grid_pinpoints, image_size); diff --git a/examples/llava/mtmd.cpp b/examples/llava/mtmd.cpp index 58503d0b22c33..114c274bc1250 100644 --- a/examples/llava/mtmd.cpp +++ b/examples/llava/mtmd.cpp @@ -41,14 +41,14 @@ struct mtmd_context { }; struct mtmd_image_tokens_data { - clip_image_f32_batch_ptr batch_f32; // preprocessed image patches + clip_image_f32_batch batch_f32; // preprocessed image patches }; struct mtmd_image_tokens { uint32_t nx; // number of tokens in x direction uint32_t ny; // number of tokens in y direction uint32_t n_tokens() const { return nx * ny; } - clip_image_f32_batch_ptr batch_f32; // preprocessed image patches + clip_image_f32_batch batch_f32; // preprocessed image patches }; mtmd_context * mtmd_init_from_file(const char * mmproj_fname, @@ -141,8 +141,8 @@ mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx, std::memcpy(img_u8->buf.data(), bitmaps[i_img].data.data(), img_u8->nx * img_u8->ny * 3); // preprocess image - clip_image_f32_batch_ptr batch_f32(new clip_image_f32_batch); - bool ok = clip_image_preprocess(ctx->ctx_clip, img_u8.get(), batch_f32.get()); + clip_image_f32_batch batch_f32; + bool ok = clip_image_preprocess(ctx->ctx_clip, img_u8.get(), &batch_f32); if (!ok) { LOG_ERR("Unable to preprocess image\n"); return nullptr; @@ -181,7 +181,7 @@ int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens) bool ok = clip_image_batch_encode( ctx->ctx_clip, ctx->n_threads, - image_tokens->batch_f32.get(), + &image_tokens->batch_f32, ctx->image_embd_v.data()); return ok ? 0 : 1; }