@@ -1539,7 +1539,37 @@ static void quantize_row_q8_0c(const float * restrict x, void * restrict vy, int
1539
1539
int8_t * restrict qs = vy ;
1540
1540
float * restrict ds = (float * ) ((uint8_t * ) vy + nb * QK8_0C );
1541
1541
1542
- #if __AVX512F__
1542
+ #if defined(__ARM_NEON )
1543
+ for (int i = 0 ; i < nb ; i ++ ) {
1544
+ float32x4_t srcv [8 ];
1545
+ float32x4_t asrcv [8 ];
1546
+ float32x4_t amaxv [8 ];
1547
+
1548
+ for (int l = 0 ; l < 8 ; l ++ ) srcv [l ] = vld1q_f32 (x + i * 32 + 4 * l );
1549
+ for (int l = 0 ; l < 8 ; l ++ ) asrcv [l ] = vabsq_f32 (srcv [l ]);
1550
+
1551
+ for (int l = 0 ; l < 4 ; l ++ ) amaxv [2 * l ] = vmaxq_f32 (asrcv [2 * l ], asrcv [2 * l + 1 ]);
1552
+ for (int l = 0 ; l < 2 ; l ++ ) amaxv [4 * l ] = vmaxq_f32 (amaxv [4 * l ], amaxv [4 * l + 2 ]);
1553
+ for (int l = 0 ; l < 1 ; l ++ ) amaxv [8 * l ] = vmaxq_f32 (amaxv [8 * l ], amaxv [8 * l + 4 ]);
1554
+
1555
+ const float amax = vmaxvq_f32 (amaxv [0 ]);
1556
+
1557
+ const float d = amax / ((1 << 7 ) - 1 );
1558
+ const float id = d ? 1.0f /d : 0.0f ;
1559
+
1560
+ ds [i ] = d ;
1561
+
1562
+ for (int l = 0 ; l < 8 ; l ++ ) {
1563
+ const float32x4_t v = vmulq_n_f32 (srcv [l ], id );
1564
+ const int32x4_t vi = vcvtnq_s32_f32 (v );
1565
+
1566
+ qs [i * QK8_0C + 4 * l + 0 ] = vgetq_lane_s32 (vi , 0 );
1567
+ qs [i * QK8_0C + 4 * l + 1 ] = vgetq_lane_s32 (vi , 1 );
1568
+ qs [i * QK8_0C + 4 * l + 2 ] = vgetq_lane_s32 (vi , 2 );
1569
+ qs [i * QK8_0C + 4 * l + 3 ] = vgetq_lane_s32 (vi , 3 );
1570
+ }
1571
+ }
1572
+ #elif defined(__AVX512F__ )
1543
1573
for (int i = 0 ; i < nb ; i ++ ) {
1544
1574
const __m512 x0 = _mm512_loadu_ps ( x + i * QK8_0C );
1545
1575
const __m512 x1 = _mm512_loadu_ps ( x + i * QK8_0C + QK8_0C /2 );
@@ -2817,7 +2847,69 @@ static void ggml_vec_dot_q4_0c_q8_0c(const int n, float * restrict s, const void
2817
2847
2818
2848
float sumf = 0.0 ;
2819
2849
2820
- #if __AVX512F__
2850
+ #if defined(__ARM_NEON )
2851
+ float32x4_t sumv0 = vdupq_n_f32 (0.0f );
2852
+ float32x4_t sumv1 = vdupq_n_f32 (0.0f );
2853
+
2854
+ for (int i = 0 ; i < nb /2 ; i ++ ) {
2855
+ const int dst0 = i + i /2 * 2 ; // 0, 1, 4, 5, 8, 9, ...
2856
+ const int dst1 = i + i /2 * 2 + 2 ; // 2, 3, 6, 7, 10, 11 ...
2857
+
2858
+ const uint8x16_t m4b = vdupq_n_u8 (0xf );
2859
+ const int8x16_t s8b = vdupq_n_s8 (0x8 );
2860
+
2861
+ const uint8x16_t v0_01l = vld1q_u8 (& xqs [i * QK4_0 ]);
2862
+ const uint8x16_t v0_01h = vld1q_u8 (& xqs [i * QK4_0 + QK4_0 /2 ]);
2863
+
2864
+ // 4-bit -> 8-bit
2865
+ const int8x16_t v0_0l = vreinterpretq_s8_u8 (vandq_u8 (v0_01l , m4b ));
2866
+ const int8x16_t v0_0h = vreinterpretq_s8_u8 (vandq_u8 (v0_01h , m4b ));
2867
+ const int8x16_t v0_1l = vreinterpretq_s8_u8 (vshrq_n_u8 (v0_01l , 4 ));
2868
+ const int8x16_t v0_1h = vreinterpretq_s8_u8 (vshrq_n_u8 (v0_01h , 4 ));
2869
+
2870
+ // sub 8
2871
+ const int8x16_t v0_0ls = vsubq_s8 (v0_0l , s8b );
2872
+ const int8x16_t v0_0hs = vsubq_s8 (v0_0h , s8b );
2873
+ const int8x16_t v0_1ls = vsubq_s8 (v0_1l , s8b );
2874
+ const int8x16_t v0_1hs = vsubq_s8 (v0_1h , s8b );
2875
+
2876
+ // load y
2877
+ const int8x16_t v1_0l = vld1q_s8 (& yqs [dst0 * QK8_0C ]);
2878
+ const int8x16_t v1_0h = vld1q_s8 (& yqs [dst0 * QK8_0C + 16 ]);
2879
+ const int8x16_t v1_1l = vld1q_s8 (& yqs [dst1 * QK8_0C ]);
2880
+ const int8x16_t v1_1h = vld1q_s8 (& yqs [dst1 * QK8_0C + 16 ]);
2881
+
2882
+ #if defined(__ARM_FEATURE_DOTPROD )
2883
+ // dot product into int32x4_t
2884
+ const int32x4_t p_0 = vdotq_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_0ls , v1_0l ), v0_0hs , v1_0h );
2885
+ const int32x4_t p_1 = vdotq_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_1ls , v1_1l ), v0_1hs , v1_1h );
2886
+
2887
+ sumv0 = vmlaq_n_f32 (sumv0 , vcvtq_f32_s32 (p_0 ), xds [dst0 ]* yds [dst0 ]);
2888
+ sumv1 = vmlaq_n_f32 (sumv1 , vcvtq_f32_s32 (p_1 ), xds [dst1 ]* yds [dst1 ]);
2889
+ #else
2890
+ const int16x8_t pl0l = vmull_s8 (vget_low_s8 (v0_0ls ), vget_low_s8 (v1_0l ));
2891
+ const int16x8_t pl0h = vmull_s8 (vget_high_s8 (v0_0ls ), vget_high_s8 (v1_0l ));
2892
+ const int16x8_t ph0l = vmull_s8 (vget_low_s8 (v0_0hs ), vget_low_s8 (v1_0h ));
2893
+ const int16x8_t ph0h = vmull_s8 (vget_high_s8 (v0_0hs ), vget_high_s8 (v1_0h ));
2894
+
2895
+ const int16x8_t pl1l = vmull_s8 (vget_low_s8 (v0_1ls ), vget_low_s8 (v1_1l ));
2896
+ const int16x8_t pl1h = vmull_s8 (vget_high_s8 (v0_1ls ), vget_high_s8 (v1_1l ));
2897
+ const int16x8_t ph1l = vmull_s8 (vget_low_s8 (v0_1hs ), vget_low_s8 (v1_1h ));
2898
+ const int16x8_t ph1h = vmull_s8 (vget_high_s8 (v0_1hs ), vget_high_s8 (v1_1h ));
2899
+
2900
+ const int32x4_t pl0 = vaddq_s32 (vpaddlq_s16 (pl0l ), vpaddlq_s16 (pl0h ));
2901
+ const int32x4_t ph0 = vaddq_s32 (vpaddlq_s16 (ph0l ), vpaddlq_s16 (ph0h ));
2902
+ const int32x4_t pl1 = vaddq_s32 (vpaddlq_s16 (pl1l ), vpaddlq_s16 (pl1h ));
2903
+ const int32x4_t ph1 = vaddq_s32 (vpaddlq_s16 (ph1l ), vpaddlq_s16 (ph1h ));
2904
+
2905
+ sumv0 = vmlaq_n_f32 (sumv0 , vcvtq_f32_s32 (vaddq_s32 (pl0 , ph0 )), xds [dst0 ]* yds [dst0 ]);
2906
+ sumv1 = vmlaq_n_f32 (sumv1 , vcvtq_f32_s32 (vaddq_s32 (pl1 , ph1 )), xds [dst1 ]* yds [dst1 ]);
2907
+ #endif
2908
+ }
2909
+
2910
+ sumf = vaddvq_f32 (sumv0 ) + vaddvq_f32 (sumv1 );
2911
+
2912
+ #elif defined(__AVX512F__ )
2821
2913
// Initialize accumulator with zeros
2822
2914
__m512 acc = _mm512_setzero_ps ();
2823
2915
for (int i = 0 ; i < nb ; i += 4 ) {
0 commit comments