diff --git a/CMakeLists.txt b/CMakeLists.txt
index 50a6538..959a78b 100755
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -22,7 +22,7 @@ endif()
 
 # general
 option(CLIP_STATIC "CLIP: static link libraries" OFF)
-option(CLIP_BUILD_TEST "CLIP: build tests" ${CLIP_STANDALONE})
+option(CLIP_BUILD_TESTS "CLIP: build tests" ${CLIP_STANDALONE})
 option(CLIP_BUILD_EXAMPLES "CLIP: build examples" ${CLIP_STANDALONE})
 option(CLIP_BUILD_IMAGE_SEARCH "CLIP: build image-search" OFF)
 option(CLIP_NATIVE "CLIP: enable -march=native flag" ON)
diff --git a/README.md b/README.md
index 74fb881..7fc927e 100644
--- a/README.md
+++ b/README.md
@@ -15,6 +15,10 @@ This repo is aimed at powering useful applications based on such models on compu
 
 clip.cpp also has a short startup time compared to large ML frameworks, which makes it suitable for serverless deployments where the cold start is an issue.
 
+## Hot topics
+- 07/12/2023: Batch inference support for image encoding.
+- 07/11/2023: Semantic image search [example](examples/image-search/README.md) directly in C++.
+
 ## Note about image preprocessing
 
 PIL resizes with a two-pass, convolution-based bicubic interpolation and applies antialiasing; in PyTorch, antialiasing is optional. Matching their results numerically takes some extra attention to implement. However, I found that linear interpolation is good enough both for comparing different embeddings from this implementation and for comparing an embedding from this implementation with one from Transformers. So let's use it until we craft a proper bicubic interpolation.
diff --git a/clip.cpp b/clip.cpp
index 37fcced..b77cbba 100644
--- a/clip.cpp
+++ b/clip.cpp
@@ -5,6 +5,8 @@
 #include
 #include
 #include
+#include <pthread.h>
+
 #include "ggml/ggml.h"
 
 #include "clip.h"
@@ -23,7 +25,7 @@ size_t get_mem_req_by_size(const size_t n_tensors, const int n_image_positions)
     case 397: // base
         if (n_image_positions == 50) // patch size = 32
         {
-            return 8 * mb;
+            return 12 * mb;
         }
         else // patch size = 16
         {
@@ -54,7 +56,7 @@ size_t get_scr_buf_req_by_size(const size_t n_tensors, const int n_positions)
     case 397:
         if (n_positions <= 50)
         {
-            return 16 * mb;
+            return 32 * mb;
         }
         else
         {
@@ -252,6 +254,77 @@ bool clip_image_preprocess(const clip_ctx *ctx, const clip_image_u8 *img, clip_i
     return true;
 }
 
+// Structure to hold the data needed to preprocess one image, passed as input to a worker thread
+typedef struct
+{
+    const clip_image_u8 *input;
+    clip_image_f32 *resized;
+    const clip_ctx *ctx;
+    int n_images; // number of consecutive entries handled by the thread that receives this entry
+} ImageData;
+
+// Function to preprocess a contiguous slice of images in a thread
+void *preprocess_image(void *arg)
+{
+    ImageData *imageData = static_cast<ImageData *>(arg);
+
+    // Call the original preprocess function on every image assigned to this thread
+    for (int i = 0; i < imageData->n_images; i++)
+    {
+        clip_image_preprocess(imageData[i].ctx, imageData[i].input, imageData[i].resized);
+    }
+
+    pthread_exit(NULL);
+}
+
+// Function to batch-preprocess multiple images in parallel
+void clip_image_batch_preprocess(const clip_ctx *ctx, const int n_threads, const std::vector<clip_image_u8> &img_inputs, std::vector<clip_image_f32> &imgs_resized)
+{
+    GGML_ASSERT(img_inputs.size() == imgs_resized.size());
+    int num_threads = std::min(n_threads, static_cast<int>(img_inputs.size()));
+    int i, t;
+
+    // Divide the images among the threads
+    int images_per_thread = img_inputs.size() / num_threads;
+
+    if (num_threads == 1)
+    {
+        // Single-threaded case
+        for (i = 0; i < img_inputs.size(); i++)
+        {
+            clip_image_preprocess(ctx, &img_inputs[i], &imgs_resized[i]);
+        }
+    }
+    else
+    {
+        // Multi-threaded case
+
+        std::vector<pthread_t> threads(num_threads);
+        std::vector<ImageData> imageData(img_inputs.size());
+
+        for (t = 0; t < num_threads; t++)
+        {
+            int start_index = t * images_per_thread;
+            int end_index = (t == num_threads - 1) ? img_inputs.size() : start_index + images_per_thread;
+
+            // Create ImageData for each image in this thread's slice
+            for (i = start_index; i < end_index; i++)
+            {
+                imageData[i].input = &img_inputs[i];
+                imageData[i].resized = &imgs_resized[i];
+                imageData[i].ctx = ctx;
+            }
+
+            // Only the first entry of the slice carries the slice length
+            imageData[start_index].n_images = end_index - start_index;
+
+            // Create a thread for each slice of images
+            pthread_create(&threads[t], NULL, preprocess_image, static_cast<void *>(&imageData[start_index]));
+        }
+
+        // Wait for all threads to finish
+        for (t = 0; t < num_threads; t++)
+        {
+            pthread_join(threads[t], NULL);
+        }
+    }
+}
+
 struct clip_ctx *clip_model_load(const char *fname, const int verbosity = 1)
 {
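A note on the slice arithmetic above: integer division leaves a remainder when the image count is not a multiple of the thread count, and the last thread absorbs it. A standalone sketch of that arithmetic (the image and thread counts are illustrative):

```cpp
#include <cstdio>

int main() {
    const int n_images = 10, num_threads = 4;
    const int images_per_thread = n_images / num_threads; // 2

    for (int t = 0; t < num_threads; t++) {
        int start_index = t * images_per_thread;
        // The last thread takes everything that remains, i.e. [6, 10) here
        int end_index = (t == num_threads - 1) ? n_images : start_index + images_per_thread;
        printf("thread %d: images [%d, %d)\n", t, start_index, end_index);
    }
    return 0;
}
```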
@@ -840,7 +913,6 @@ bool clip_text_encode(
 
     struct ggml_context *ctx0 = ggml_init(params);
     struct ggml_cgraph gf = {};
-    gf.n_threads = n_threads;
 
     static size_t scr0_size = get_scr_buf_req_by_size(ctx->text_model.tensors.size() + ctx->vision_model.tensors.size(), N);
     static void *scr0 = malloc(scr0_size);
@@ -991,7 +1063,7 @@ bool clip_text_encode(
 
     // run the computation
     ggml_build_forward_expand(&gf, embeddings);
-    ggml_graph_compute(ctx0, &gf);
+    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
 
     // print
 #ifdef CLIP_DEBUG
@@ -1058,6 +1130,17 @@ bool clip_image_encode(
     int n_threads,
     const clip_image_f32 &img,
     float *vec)
+{
+    std::vector<clip_image_f32> imgs;
+    imgs.push_back(img);
+    return clip_image_batch_encode(ctx, n_threads, imgs, vec);
+}
+
+bool clip_image_batch_encode(
+    const clip_ctx *ctx,
+    int n_threads,
+    const std::vector<clip_image_f32> &imgs,
+    float *vec)
 {
     const auto &model = ctx->vision_model;
     const auto &hparams = model.hparams;
@@ -1072,6 +1155,7 @@ bool clip_image_encode(
     const int n_layer = hparams.n_layer;
     const int n_intermediate = hparams.n_intermediate;
     const int projection_dim = hparams.projection_dim;
+    int batch_size = imgs.size();
 
     auto &buf_compute = ctx->buf_compute;
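Since `clip_image_encode` now just forwards a batch of one to `clip_image_batch_encode`, callers of the batched entry point are responsible for sizing the output buffer for the whole batch. A minimal caller-side sketch, reading the checkpoint's `projection_dim` from the context the way the benchmark below does (`alloc_output` is a hypothetical helper, not part of the API):

```cpp
#include <vector>
#include "clip.h"

// Hypothetical helper: allocate one embedding slot per image in the batch.
std::vector<float> alloc_output(const clip_ctx *ctx, size_t batch_size) {
    const int vec_dim = ctx->text_model.hparams.projection_dim; // same field the benchmark uses
    // clip_image_batch_encode writes batch_size embeddings back to back
    return std::vector<float>(vec_dim * batch_size);
}
```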
@@ -1083,43 +1167,52 @@ bool clip_image_encode(
 
     struct ggml_context *ctx0 = ggml_init(params);
     struct ggml_cgraph gf = {};
-    gf.n_threads = n_threads;
 
     static size_t scr0_size = get_scr_buf_req_by_size(ctx->text_model.tensors.size() + ctx->vision_model.tensors.size(), num_positions);
     static void *scr0 = malloc(scr0_size);
 
-    struct ggml_tensor *inp = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size, image_size, 3, 1);
+    struct ggml_tensor *inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size, image_size, 3, batch_size);
 
     {
-        float *data = (float *)ggml_get_data(inp);
+        float *data = (float *)ggml_get_data(inp_raw);
 
-        const int nx = img.nx;
-        const int ny = img.ny;
-        const int n = nx * ny;
-
-        GGML_ASSERT(nx == image_size && ny == image_size);
-
-        for (int k = 0; k < 3; k++)
-        {
-            for (int y = 0; y < ny; y++)
-            {
-                for (int x = 0; x < nx; x++)
-                {
-                    data[k * n + y * nx + x] = img.data[3 * (y * nx + x) + k];
-                }
-            }
-        }
+        for (int b = 0; b < batch_size; b++)
+        {
+            const int nx = imgs[b].nx;
+            const int ny = imgs[b].ny;
+            GGML_ASSERT(nx == image_size && ny == image_size);
+
+            const int n = nx * ny;
+
+            for (int k = 0; k < 3; k++)
+            {
+                for (int y = 0; y < ny; y++)
+                {
+                    for (int x = 0; x < nx; x++)
+                    {
+                        data[(b * 3 * n) + k * n + y * nx + x] = imgs[b].data[3 * (y * nx + x) + k];
+                    }
+                }
+            }
+        }
     }
 
-    inp = ggml_conv_2d_sk_p0(ctx0, model.patch_embeddings, inp);
-    inp = ggml_reshape_2d(ctx0, inp, num_patches, hidden_size);
-    inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
+    struct ggml_tensor *inp = ggml_conv_2d(ctx0, model.patch_embeddings, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
+
+    inp = ggml_reshape_3d(ctx0, inp, num_patches, hidden_size, batch_size);
+    inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 0, 2, 3));
 
     // concat class_embeddings and patch_embeddings
-    struct ggml_tensor *embeddings = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hidden_size, num_positions);
+    struct ggml_tensor *embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size);
+
     ggml_set_zero(embeddings);
-    embeddings = ggml_acc(ctx0, embeddings, model.class_embedding, embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0);
-    embeddings = ggml_acc(ctx0, embeddings, inp, embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], ggml_element_size(model.class_embedding) * hidden_size);
+    struct ggml_tensor *temp = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, 1, batch_size);
+
+    embeddings = ggml_acc(ctx0, embeddings, ggml_repeat(ctx0, model.class_embedding, temp), embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0);
+    embeddings = ggml_acc(ctx0, embeddings, inp, embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]);
 
     struct ggml_tensor *positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions);
     for (int i = 0; i < num_positions; i++)
@@ -1127,7 +1220,7 @@ bool clip_image_encode(
         ggml_set_i32_1d(positions, i, i);
     }
 
-    embeddings = ggml_add(ctx0, embeddings, ggml_get_rows(ctx0, model.position_embeddings, positions));
+    embeddings = ggml_add(ctx0, embeddings, ggml_repeat(ctx0, ggml_get_rows(ctx0, model.position_embeddings, positions), embeddings));
 
     // pre-layernorm
     {
@@ -1145,6 +1238,8 @@ bool clip_image_encode(
     {
         struct ggml_tensor *cur = embeddings; // embeddings = residual, cur = hidden_states
 
+        const size_t nb_q_w = model.layers[il].q_w->nb[0];
+
         ggml_set_scratch(ctx0, {0, scr0_size, scr0});
 
         // layernorm1
@@ -1160,44 +1255,48 @@ bool clip_image_encode(
 
         // self-attention
         {
+
             struct ggml_tensor *Q = ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].q_b, cur),
-                                             ggml_mul_mat(ctx0, model.layers[il].q_w, cur));
+                                             ggml_mul_mat(ctx0, model.layers[il].q_w,
+                                                          cur));
 
             Q = ggml_scale_inplace(ctx0, Q, ggml_new_f32(ctx0, 1.0f / sqrt((float)d_head)));
-            Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_positions, 1);
+            Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_positions, batch_size);
             Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
-            Q = ggml_reshape_3d(ctx0, Q, d_head, num_positions, n_head);
+            Q = ggml_reshape_3d(ctx0, Q, d_head, num_positions, n_head * batch_size);
 
-            struct ggml_tensor *K =
-                ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].k_b, cur),
-                         ggml_mul_mat(ctx0, model.layers[il].k_w, cur));
+            struct ggml_tensor *K = ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].k_b, cur),
+                                             ggml_mul_mat(ctx0, model.layers[il].k_w,
+                                                          cur));
 
-            K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, 1);
+            K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size);
             K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
-            K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head);
+            K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size);
 
-            struct ggml_tensor *V =
-                ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].v_b, cur),
-                         ggml_mul_mat(ctx0, model.layers[il].v_w, cur));
-            V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, 1);
+            struct ggml_tensor *V = ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].v_b, cur),
+                                             ggml_mul_mat(ctx0, model.layers[il].v_w,
+                                                          cur));
+
+            V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size);
             V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
-            V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head);
+            V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size);
 
             struct ggml_tensor *KQ = ggml_mul_mat(ctx0, K, Q);
             KQ = ggml_soft_max_inplace(ctx0, KQ);
 
             struct ggml_tensor *KQV = ggml_mul_mat(ctx0, V, KQ);
-            KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, 1);
+            KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, batch_size);
             KQV = ggml_cont(ctx0, ggml_permute(ctx0, KQV, 0, 2, 1, 3));
 
             cur = ggml_cpy(ctx0, KQV,
-                           ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hidden_size, num_positions));
+                           ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size));
         }
 
         // attention output
         cur = ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].o_b, cur),
-                       ggml_mul_mat(ctx0, model.layers[il].o_w, cur));
+                       ggml_mul_mat(ctx0, model.layers[il].o_w,
+                                    cur));
 
         // re-add the layer input, e.g., residual
         cur = ggml_add(ctx0, cur, embeddings);
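The reshape/permute sequence above folds the batch dimension into ggml's third ("batch of matrices") dimension, so a single `ggml_mul_mat` call covers every head of every image. A toy sketch of the resulting shapes, with illustrative ViT-B/32-like numbers:

```cpp
#include <cstdio>

int main() {
    // Illustrative numbers: 768 hidden, 12 heads, 49 patches + 1 class token, batch of 4
    const int hidden_size = 768, n_head = 12, num_positions = 50, batch_size = 4;
    const int d_head = hidden_size / n_head; // 64

    // After reshape_4d -> permute -> reshape_3d, Q and K are 3D tensors of shape
    // (d_head, num_positions, n_head * batch_size), so ggml_mul_mat produces one
    // num_positions x num_positions attention matrix per head per image:
    printf("Q/K: (%d, %d, %d)\n", d_head, num_positions, n_head * batch_size);
    printf("KQ : (%d, %d, %d)\n", num_positions, num_positions, n_head * batch_size);
    return 0;
}
```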
@@ -1236,14 +1335,17 @@ bool clip_image_encode(
         // residual 2
         cur = ggml_add(ctx0, embeddings, cur);
 
-        // ggml_set_name(cur, "check");
         embeddings = cur;
     }
 
     // get the output of cls token, e.g., 0th index
-    struct ggml_tensor *cls = ggml_new_i32(ctx0, 0);
-    embeddings = ggml_get_rows(ctx0, embeddings, cls);
+    struct ggml_tensor *cls = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, batch_size);
+    for (int b = 0; b < batch_size; b++)
+    {
+        ggml_set_i32_1d(cls, b, b * num_positions);
+    }
+    embeddings = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, embeddings, hidden_size, num_positions * batch_size), cls);
 
     // post-layernorm
     {
@@ -1262,15 +1364,21 @@ bool clip_image_encode(
     embeddings = ggml_mul_mat(ctx0, model.projection, embeddings);
 
     // normalize output embeddings
-    ggml_tensor *length = ggml_sqrt(ctx0,
-                                    ggml_sum(ctx0, ggml_sqr(ctx0, embeddings)));
-    embeddings = ggml_scale_inplace(ctx0, embeddings, ggml_div(ctx0, ggml_new_f32(ctx0, 1.0f), length));
+    struct ggml_tensor *output = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, projection_dim, batch_size);
 
-    ggml_set_name(embeddings, "check");
+    for (int b = 0; b < batch_size; b++)
+    {
+        struct ggml_tensor *embedding = ggml_get_rows(ctx0, embeddings, ggml_new_i32(ctx0, b));
+        ggml_tensor *length = ggml_sqrt(ctx0,
+                                        ggml_sum(ctx0, ggml_sqr(ctx0, embedding)));
+        embedding = ggml_scale_inplace(ctx0, embedding, ggml_div(ctx0, ggml_new_f32(ctx0, 1.0f), length));
+        output = ggml_acc(ctx0, output, embedding, output->nb[1], output->nb[2], output->nb[3], b * ggml_nbytes(embedding));
+    }
+    ggml_set_name(output, "check");
 
     // run the computation
-    ggml_build_forward_expand(&gf, embeddings);
-    ggml_graph_compute(ctx0, &gf);
+    ggml_build_forward_expand(&gf, output);
+    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
 
     // print
 #ifdef CLIP_DEBUG
@@ -1313,6 +1421,7 @@ bool clip_image_encode(
     };
 
     auto *t = ggml_get_tensor(ctx0, "check");
+    // auto t = inp_raw;
     if (t->type == GGML_TYPE_F32)
     {
         print_t_f32(t);
@@ -1326,7 +1435,7 @@ bool clip_image_encode(
     printf("used_mem = %zu\n", ggml_used_mem(ctx0));
 #endif
 
-    memcpy(vec, ggml_get_data_f32(embeddings), sizeof(float) * projection_dim);
+    memcpy(vec, ggml_get_data_f32(output), sizeof(float) * projection_dim * batch_size);
 
     ggml_free(ctx0);
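Putting the two new entry points together, a minimal end-to-end sketch of the batched image pipeline; the model path and image file names are placeholders:

```cpp
#include <string>
#include <vector>
#include "clip.h"

int main() {
    clip_ctx *ctx = clip_model_load("models/model.bin", 1); // placeholder path
    if (!ctx) return 1;

    const int n_threads = 4;
    std::vector<std::string> paths = {"cat.jpg", "dog.jpg", "car.jpg", "tree.jpg"}; // placeholder files
    std::vector<clip_image_u8> inputs(paths.size());
    std::vector<clip_image_f32> resized(paths.size());

    for (size_t i = 0; i < paths.size(); i++)
    {
        if (!clip_image_load_from_file(paths[i], inputs[i]))
            return 1;
    }

    clip_image_batch_preprocess(ctx, n_threads, inputs, resized);

    // One embedding of projection_dim floats per image, written back to back
    const int vec_dim = ctx->text_model.hparams.projection_dim;
    std::vector<float> vecs(vec_dim * paths.size());
    clip_image_batch_encode(ctx, n_threads, resized, vecs.data());

    return 0;
}
```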
diff --git a/clip.h b/clip.h
index 7978019..bc2b53d 100644
--- a/clip.h
+++ b/clip.h
@@ -200,6 +200,7 @@ std::vector<clip_vocab::id> clip_tokenize(const clip_ctx *ctx, const std::string
 
 bool clip_image_load_from_file(const std::string &fname, clip_image_u8 &img);
 bool clip_image_preprocess(const clip_ctx *ctx, const clip_image_u8 *img, clip_image_f32 *res);
+void clip_image_batch_preprocess(const clip_ctx *ctx, const int n_threads, const std::vector<clip_image_u8> &img_inputs, std::vector<clip_image_f32> &imgs_resized);
 
 bool clip_text_encode(
     const clip_ctx *ctx,
@@ -219,8 +220,11 @@ bool clip_compare_text_and_image(clip_ctx *ctx, int n_threads, std::string &text
 float clip_similarity_score(float *vec1, float *vec2, int vec_dim);
 bool softmax_with_sorting(float *arr, int length, float *sorted_scores, int *indices);
 
-// utils for debugging
-void write_floats_to_file(float *array, int size, char *filename);
+bool clip_image_batch_encode(
+    const clip_ctx *ctx,
+    int n_threads,
+    const std::vector<clip_image_f32> &imgs,
+    float *vec);
 
 // #ifdef __cplusplus
 // }
diff --git a/ggml b/ggml
index 93b94a2..5621652 160000
--- a/ggml
+++ b/ggml
@@ -1 +1 @@
-Subproject commit 93b94a2d41e880cb2abfb708535d5b04ad05b7a5
+Subproject commit 56216523fa1df0c0bc36201dbecd8e0a01668d91
diff --git a/tests/benchmark.cpp b/tests/benchmark.cpp
index 0b61073..f23da4c 100644
--- a/tests/benchmark.cpp
+++ b/tests/benchmark.cpp
@@ -44,7 +44,7 @@ int main(int argc, char **argv)
         return 1;
     }
 
-    fprintf(fout, "%s: %d directories found in %s\n\n", __func__, n_labels, dir_path.c_str());
+    fprintf(fout, "%s: %zu directories found in %s\n\n", __func__, n_labels, dir_path.c_str());
 
     auto ctx = clip_model_load(model_path.c_str(), 2);
     if (!ctx)
@@ -53,14 +53,12 @@ int main(int argc, char **argv)
         return 1;
     }
 
+    const size_t batch_size = 4;
+    const size_t n_threads = 4;
+
     const int vec_dim = ctx->text_model.hparams.projection_dim;
 
-    // allocate memory for text vectors
-    float *txt_vecs = (float *)malloc(n_labels * vec_dim * sizeof(float));
-    if (!txt_vecs)
-    {
-        printf("%s: Could not allocate memory for %d vectors of %d dimensions\n", __func__, n_labels, vec_dim);
-    }
+    float txt_vecs[n_labels * vec_dim];
 
     ggml_time_init();
@@ -73,7 +71,7 @@ int main(int argc, char **argv)
     for (const auto &entry : result)
     {
         auto tokens = clip_tokenize(ctx, entry.first);
-        if (!clip_text_encode(ctx, 4, tokens, txt_vecs + label_idx * vec_dim))
+        if (!clip_text_encode(ctx, n_threads, tokens, txt_vecs + label_idx * vec_dim))
         {
             printf("%s: Could not encode the label at index %d: %s\n", __func__, label_idx, entry.first.c_str());
             return 1;
@@ -88,12 +86,13 @@ int main(int argc, char **argv)
     int n_total_items = 0;         // total number of images processed
     float total_acc1_score = 0.0f; // total accuracy at 1 for the entire dataset
     float total_acc5_score = 0.0f; // total accuracy at 5 for the entire dataset
-    float img_vec[vec_dim];
+    float img_vecs[vec_dim * batch_size];
+    float similarities[n_labels];
     float sorted_scores[n_labels];
     int indices[n_labels];
-    clip_image_u8 img;
-    clip_image_f32 img_res;
+    std::vector<clip_image_u8> img_inputs(batch_size);
+    std::vector<clip_image_f32> imgs_resized(batch_size);
 
     // print table headers
     fprintf(fout, "| class name | acc@1 | acc@5 |\n");
@@ -107,56 +106,59 @@ int main(int argc, char **argv)
         int n_acc1 = 0;
         int n_acc5 = 0;
 
-        int64_t t_start_encode_images = ggml_time_us();
+        size_t n_batched = (entry.second.size() / batch_size) * batch_size;
 
-        for (auto &file_path : entry.second)
+        for (size_t i = 0; i < n_batched; i += batch_size)
         {
-            if (!clip_image_load_from_file(file_path, img))
+            for (size_t ib = i; ib < i + batch_size; ib++)
             {
-                printf("%s: cannot load file from %s\n", __func__, file_path.c_str());
-                return 1;
-            }
+                std::string file_path = entry.second[ib];
 
-            if (!clip_image_preprocess(ctx, &img, &img_res))
-            {
-                printf("%s: cannot preprocess image loaded from %s\n", __func__, file_path.c_str());
-                return 1;
+                if (!clip_image_load_from_file(file_path, img_inputs[ib % batch_size]))
+                {
+                    printf("%s: cannot load file from %s\n", __func__, file_path.c_str());
+                    return 1;
+                }
             }
 
-            clip_image_encode(ctx, 4, img_res, img_vec);
-            for (size_t i = 0; i < n_labels; i++)
-            {
-                similarities[i] = clip_similarity_score(img_vec, txt_vecs + i * vec_dim, vec_dim);
-            }
+            clip_image_batch_preprocess(ctx, n_threads, img_inputs, imgs_resized);
 
-            softmax_with_sorting(similarities, n_labels, sorted_scores, indices);
-            for (int j = 0; j < 5; j++)
+            clip_image_batch_encode(ctx, n_threads, imgs_resized, img_vecs);
+
+            for (size_t b = 0; b < batch_size; b++)
             {
-                if (j == 0 && indices[j] == label_idx)
+                for (size_t j = 0; j < n_labels; j++)
                 {
-                    n_acc1 += 1;
-                    n_acc5 += 1;
-                    break;
+                    similarities[j] = clip_similarity_score(img_vecs + b * vec_dim, txt_vecs + j * vec_dim, vec_dim);
                 }
-                else if (indices[j] == label_idx)
+                softmax_with_sorting(similarities, n_labels, sorted_scores, indices);
+
+                for (int k = 0; k < 5; k++)
                 {
-                    n_acc5 += 1;
-                    break;
+                    if (k == 0 && indices[k] == label_idx)
+                    {
+                        n_acc1 += 1;
+                        n_acc5 += 1;
+                        break;
+                    }
+                    else if (indices[k] == label_idx)
+                    {
+                        n_acc5 += 1;
+                        break;
+                    }
                 }
-            }
 
-            n_items += 1;
-            n_total_items += 1;
+                n_items += 1;
+                n_total_items += 1;
+            }
         }
 
         float acc1_score = (float)n_acc1 / n_items;
         float acc5_score = (float)n_acc5 / n_items;
         total_acc1_score += acc1_score;
         total_acc5_score += acc5_score;
 
-        // printf("%s: acc@1 = %2.4f - acc@5 = %2.4f\n", entry.first.c_str(), acc1_score, acc5_score);
         fprintf(fout, "| %-*s ", 20, entry.first.c_str());
         fprintf(fout, "| %2.4f | %2.4f |\n", acc1_score, acc5_score);
-
         label_idx += 1;
     }
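Note that the benchmark rounds each class's file list down to a whole number of batches, so up to `batch_size - 1` trailing images per class are skipped. A sketch of that rounding (the file count is illustrative):

```cpp
#include <cstdio>

int main() {
    const size_t n_files = 1003, batch_size = 4;
    // Floor division trims the list to full batches: the last 3 files are not evaluated
    const size_t n_batched = (n_files / batch_size) * batch_size; // 1000
    printf("processing %zu of %zu files\n", n_batched, n_files);
    return 0;
}
```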