@@ -172,6 +172,7 @@ struct clip_hparams {
172
172
std::unordered_set<int32_t > vision_feature_layer;
173
173
int32_t attn_window_size = 0 ;
174
174
int32_t n_wa_pattern = 0 ;
175
+ int32_t spatial_merge_size = 0 ;
175
176
};
176
177
177
178
struct clip_layer {
@@ -232,6 +233,7 @@ struct clip_vision_model {
232
233
struct ggml_tensor * projection;
233
234
234
235
// LLaVA projection
236
+ struct ggml_tensor * mm_input_norm_w = nullptr ;
235
237
struct ggml_tensor * mm_0_w = nullptr ;
236
238
struct ggml_tensor * mm_0_b = nullptr ;
237
239
struct ggml_tensor * mm_2_w = nullptr ;
@@ -311,6 +313,7 @@ struct clip_vision_model {
311
313
312
314
// pixtral
313
315
struct ggml_tensor * token_embd_img_break = nullptr ;
316
+ struct ggml_tensor * mm_patch_merger_w = nullptr ;
314
317
};
315
318
316
319
struct clip_ctx {
@@ -637,6 +640,7 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
637
640
const int d_head = hidden_size / n_head;
638
641
const int n_layer = hparams.n_layer ;
639
642
const float eps = hparams.eps ;
643
+ const int n_merge = hparams.spatial_merge_size ;
640
644
641
645
struct ggml_init_params params = {
642
646
/* .mem_size =*/ ctx->buf_compute_meta .size (),
@@ -721,7 +725,13 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
721
725
{
722
726
ggml_tensor * gate_proj = ggml_mul_mat (ctx0, model.layers [il].ff_gate_w , cur);
723
727
ggml_tensor * up_proj = ggml_mul_mat (ctx0, model.layers [il].ff_up_w , cur);
724
- gate_proj = ggml_silu (ctx0, gate_proj); // pixtral uses silu
728
+ if (ctx->use_silu ) {
729
+ gate_proj = ggml_silu (ctx0, gate_proj);
730
+ } else if (ctx->use_gelu ) {
731
+ gate_proj = ggml_gelu (ctx0, gate_proj);
732
+ } else {
733
+ GGML_ABORT (" Pixtral: Unsupported activation" );
734
+ }
725
735
cur = ggml_mul (ctx0, up_proj, gate_proj);
726
736
cur = ggml_mul_mat (ctx0, model.layers [il].ff_down_w , cur);
727
737
}
@@ -732,14 +742,42 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
732
742
embeddings = cur;
733
743
}
734
744
735
- // LlavaMultiModalProjector (with GELU activation)
745
+ // mistral small 3.1 patch merger
746
+ // ref: https://github.com/huggingface/transformers/blob/7a3e208892c06a5e278144eaf38c8599a42f53e7/src/transformers/models/mistral3/modeling_mistral3.py#L67
747
+ if (model.mm_patch_merger_w ) {
748
+ GGML_ASSERT (hparams.spatial_merge_size > 0 );
749
+
750
+ ggml_tensor * cur = embeddings;
751
+ cur = ggml_mul (ctx0, ggml_rms_norm (ctx0, cur, eps), model.mm_input_norm_w );
752
+
753
+ // reshape image tokens to 2D grid
754
+ cur = ggml_reshape_3d (ctx0, cur, hidden_size, n_patches_x, n_patches_y);
755
+ cur = ggml_permute (ctx0, cur, 2 , 0 , 1 , 3 ); // [x, y, hidden_size]
756
+ cur = ggml_cont (ctx0, cur);
757
+
758
+ // torch.nn.functional.unfold is just an im2col under the hood
759
+ // we just need a dummy kernel to make it work
760
+ ggml_tensor * kernel = ggml_view_3d (ctx0, cur, n_merge, n_merge, cur->ne [2 ], 0 , 0 , 0 );
761
+ cur = ggml_im2col (ctx0, kernel, cur, n_merge, n_merge, 0 , 0 , 1 , 1 , true , inp->type );
762
+
763
+ // project to hidden_size
764
+ cur = ggml_reshape_2d (ctx0, cur, cur->ne [0 ], cur->ne [1 ] * cur->ne [2 ]);
765
+ cur = ggml_mul_mat (ctx0, model.mm_patch_merger_w , cur);
766
+ embeddings = cur;
767
+ }
768
+
769
+ // LlavaMultiModalProjector (always using GELU activation)
736
770
{
737
771
embeddings = ggml_mul_mat (ctx0, model.mm_1_w , embeddings);
738
- embeddings = ggml_add (ctx0, embeddings, model.mm_1_b );
772
+ if (model.mm_1_b ) {
773
+ embeddings = ggml_add (ctx0, embeddings, model.mm_1_b );
774
+ }
739
775
740
776
embeddings = ggml_gelu (ctx0, embeddings);
741
777
embeddings = ggml_mul_mat (ctx0, model.mm_2_w , embeddings);
742
- embeddings = ggml_add (ctx0, embeddings, model.mm_2_b );
778
+ if (model.mm_2_b ) {
779
+ embeddings = ggml_add (ctx0, embeddings, model.mm_2_b );
780
+ }
743
781
}
744
782
745
783
// arrangement of the [IMG_BREAK] token
@@ -749,11 +787,14 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
749
787
// and then concatenate the [IMG_BREAK] token to the end of each row, aka n_patches_per_row dimension
750
788
// after the concatenation, we have a tensor with shape [hidden_size, n_patches_per_row + 1, n_rows]
751
789
790
+ const int p_y = n_merge > 0 ? n_patches_y / n_merge : n_patches_y;
791
+ const int p_x = n_merge > 0 ? n_patches_x / n_merge : n_patches_x;
792
+ const int p_total = p_x * p_y;
752
793
const int n_embd_text = embeddings->ne [0 ];
753
- const int n_tokens_output = num_patches + n_patches_y - 1 ; // one [IMG_BREAK] per row, except the last row
794
+ const int n_tokens_output = p_total + p_y - 1 ; // one [IMG_BREAK] per row, except the last row
754
795
755
- ggml_tensor * cur = ggml_reshape_3d (ctx0, embeddings, n_embd_text, n_patches_x, n_patches_y );
756
- ggml_tensor * tok = ggml_new_tensor_3d (ctx0, embeddings->type , n_embd_text, 1 , n_patches_y );
796
+ ggml_tensor * cur = ggml_reshape_3d (ctx0, embeddings, n_embd_text, p_x, p_y );
797
+ ggml_tensor * tok = ggml_new_tensor_3d (ctx0, embeddings->type , n_embd_text, 1 , p_y );
757
798
tok = ggml_scale (ctx0, tok, 0.0 ); // clear the tensor
758
799
tok = ggml_add (ctx0, tok, model.token_embd_img_break );
759
800
cur = ggml_concat (ctx0, cur, tok, 1 );
@@ -1734,6 +1775,7 @@ struct clip_model_loader {
1734
1775
case PROJECTOR_TYPE_PIXTRAL:
1735
1776
{
1736
1777
hparams.rope_theta = 10000 .0f ;
1778
+ get_u32 (KEY_SPATIAL_MERGE_SIZE, hparams.spatial_merge_size , false );
1737
1779
} break ;
1738
1780
case PROJECTOR_TYPE_QWEN25VL:
1739
1781
{
@@ -1957,11 +1999,14 @@ struct clip_model_loader {
1957
1999
case PROJECTOR_TYPE_PIXTRAL:
1958
2000
{
1959
2001
vision_model.mm_1_w = get_tensor (string_format (TN_LLAVA_PROJ, 1 , " weight" ));
1960
- vision_model.mm_1_b = get_tensor (string_format (TN_LLAVA_PROJ, 1 , " bias" ));
2002
+ vision_model.mm_1_b = get_tensor (string_format (TN_LLAVA_PROJ, 1 , " bias" ), false );
1961
2003
vision_model.mm_2_w = get_tensor (string_format (TN_LLAVA_PROJ, 2 , " weight" ));
1962
- vision_model.mm_2_b = get_tensor (string_format (TN_LLAVA_PROJ, 2 , " bias" ));
2004
+ vision_model.mm_2_b = get_tensor (string_format (TN_LLAVA_PROJ, 2 , " bias" ), false );
1963
2005
// [IMG_BREAK] token embedding
1964
2006
vision_model.token_embd_img_break = get_tensor (TN_TOK_IMG_BREAK);
2007
+ // for mistral small 3.1
2008
+ vision_model.mm_input_norm_w = get_tensor (TN_MM_INP_NORM, false );
2009
+ vision_model.mm_patch_merger_w = get_tensor (TN_MM_PATCH_MERGER, false );
1965
2010
} break ;
1966
2011
default :
1967
2012
GGML_ASSERT (false && " unknown projector type" );
@@ -2926,8 +2971,9 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
2926
2971
} else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
2927
2972
n_patches /= ctx->vision_model .hparams .proj_scale_factor ;
2928
2973
} else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
2929
- int n_patches_x = img->nx / params.patch_size ;
2930
- int n_patches_y = img->ny / params.patch_size ;
2974
+ int n_merge = ctx->vision_model .hparams .spatial_merge_size ;
2975
+ int n_patches_x = img->nx / params.patch_size / (n_merge > 0 ? n_merge : 1 );
2976
+ int n_patches_y = img->ny / params.patch_size / (n_merge > 0 ? n_merge : 1 );
2931
2977
n_patches = n_patches_y*n_patches_x + n_patches_y - 1 ; // + one [IMG_BREAK] per row, except the last row
2932
2978
}
2933
2979
@@ -3484,7 +3530,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
3484
3530
return ctx->vision_model .mm_model_peg_0_b ->ne [0 ];
3485
3531
case PROJECTOR_TYPE_MLP:
3486
3532
case PROJECTOR_TYPE_PIXTRAL:
3487
- return ctx->vision_model .mm_2_b ->ne [0 ];
3533
+ return ctx->vision_model .mm_2_w ->ne [1 ];
3488
3534
case PROJECTOR_TYPE_MLP_NORM:
3489
3535
return ctx->vision_model .mm_3_b ->ne [0 ];
3490
3536
case PROJECTOR_TYPE_MINICPMV:
0 commit comments