Skip to content

Commit ee26ea7

Browse files
author
Yavor Ivanov
committed
sam : add the dense img pe in SAM state (#401)
1 parent b4fe0c4 commit ee26ea7

File tree

1 file changed

+92
-1
lines changed

1 file changed

+92
-1
lines changed

examples/sam/main.cpp

Lines changed: 92 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -114,6 +114,7 @@ struct sam_state {
114114
struct ggml_tensor * embd_img;
115115
struct ggml_tensor * embd_prompt_sparse;
116116
struct ggml_tensor * embd_prompt_dense;
117+
struct ggml_tensor * pe_img_dense;
117118

118119
struct ggml_context * ctx;
119120

@@ -532,7 +533,7 @@ bool sam_model_load(const std::string & fname, sam_model & model) {
532533

533534
// key + value memory
534535
{
535-
const auto & hparams = model.hparams;
536+
// const auto & hparams = model.hparams;
536537

537538
// TODO
538539
}
@@ -630,6 +631,88 @@ bool sam_model_load(const std::string & fname, sam_model & model) {
630631
return true;
631632
}
632633

634+
bool sam_fill_dense_pe(
635+
const sam_model & model,
636+
sam_state & state,
637+
int n_threads) {
638+
const auto & hparams = model.hparams;
639+
const auto & enc = model.enc_prompt;
640+
641+
const int32_t n_img_embd = hparams.n_img_embd();
642+
const float n_img_embd_inv = 1.0f / n_img_embd;
643+
644+
static size_t buf_size = 256u*1024*1024;
645+
static void * buf = malloc(buf_size);
646+
647+
struct ggml_init_params params = {
648+
/*.mem_size =*/ buf_size,
649+
/*.mem_buffer =*/ buf,
650+
/*.no_alloc =*/ false,
651+
};
652+
653+
struct ggml_context * ctx0 = ggml_init(params);
654+
struct ggml_cgraph gf = {};
655+
656+
struct ggml_tensor * xy_embed_stacked = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 2, n_img_embd, n_img_embd);
657+
658+
{
659+
float * data = (float *) ggml_get_data(xy_embed_stacked);
660+
for (int i = 0; i < n_img_embd; ++i) {
661+
const int row = 2*i*n_img_embd;
662+
const float y_val = 2 * (i + 0.5f) * n_img_embd_inv - 1;
663+
for (int j = 0; j < n_img_embd; ++j) {
664+
const float x_val = 2 * (j + 0.5f) * n_img_embd_inv - 1;
665+
data[row + 2*j + 0] = x_val;
666+
data[row + 2*j + 1] = y_val;
667+
}
668+
}
669+
}
670+
671+
struct ggml_tensor * cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, enc.pe)), xy_embed_stacked);
672+
673+
cur = ggml_scale(ctx0, cur, ggml_new_f32(ctx0, 2.0f*M_PI));
674+
675+
// concat
676+
// ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/prompt_encoder.py#L192
677+
{
678+
struct ggml_tensor * t_sin = ggml_map_unary_f32(ctx0, cur, ggml_sam_sin);
679+
struct ggml_tensor * t_cos = ggml_map_unary_f32(ctx0, cur, ggml_sam_cos);
680+
681+
cur = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, t_sin->ne[0] + t_cos->ne[0], cur->ne[1], cur->ne[2]);
682+
683+
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, t_sin, ggml_view_3d(ctx0, cur, t_sin->ne[0], t_sin->ne[1], t_sin->ne[2], cur->nb[1], cur->nb[2], 0)));
684+
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, t_cos, ggml_view_3d(ctx0, cur, t_sin->ne[0], t_sin->ne[1], t_sin->ne[2], cur->nb[1], cur->nb[2], t_sin->nb[1])));
685+
}
686+
687+
cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 2, 0, 1, 3));
688+
689+
// TODO: avoid copy
690+
cur = ggml_cpy(ctx0, cur, state.pe_img_dense);
691+
692+
// run the computation
693+
ggml_set_name(cur, "check");
694+
ggml_build_forward_expand(&gf, cur);
695+
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
696+
697+
// auto * t = ggml_get_tensor(ctx0, "check");
698+
// auto print_t_f32 = [&](struct ggml_tensor * t) {
699+
// float * data = (float *)t->data;
700+
// printf("dims: %jd %jd %jd %jd f32\n", t->ne[0], t->ne[1], t->ne[2], t->ne[3]);
701+
// for (int i = 0; i < 256; i++) {
702+
// printf("%f ", data[256*64*63 + 63*256 + i]);
703+
// }
704+
// printf("\n");
705+
// double sum = 0.0;
706+
// for (int i = 0; i < ggml_nelements(t); i++) {
707+
// sum += data[i];
708+
// }
709+
// printf("sum: %f\n", sum);
710+
// };
711+
// print_t_f32(t);
712+
713+
return true;
714+
}
715+
633716
bool sam_encode_image(
634717
const sam_model & model,
635718
sam_state & state,
@@ -1254,6 +1337,14 @@ int main(int argc, char ** argv) {
12541337

12551338
state.embd_prompt_dense = ggml_new_tensor_3d(state.ctx, GGML_TYPE_F32,
12561339
model.hparams.n_img_embd(), model.hparams.n_img_embd(), model.hparams.n_enc_out_chans);
1340+
1341+
state.pe_img_dense = ggml_new_tensor_3d(state.ctx, GGML_TYPE_F32,
1342+
model.hparams.n_img_embd(), model.hparams.n_img_embd(), model.hparams.n_enc_out_chans);
1343+
}
1344+
1345+
if (!sam_fill_dense_pe(model, state, params.n_threads)) {
1346+
fprintf(stderr, "%s: failed to get dense positional encoding\n", __func__);
1347+
return 1;
12571348
}
12581349

12591350
if (!sam_encode_image(model, state, img1, params.n_threads)) {

0 commit comments

Comments
 (0)