Skip to content

Commit 64a6a29

Browse files
committed
q4_0c: Arm Neon acceleration
Mostly copied from the q4_0 implementation
1 parent 96363e7 commit 64a6a29

File tree

1 file changed

+94
-2
lines changed

1 file changed

+94
-2
lines changed

ggml.c

Lines changed: 94 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1539,7 +1539,37 @@ static void quantize_row_q8_0c(const float * restrict x, void * restrict vy, int
15391539
int8_t * restrict qs = vy;
15401540
float * restrict ds = (float *) ((uint8_t *) vy + nb*QK8_0C);
15411541

1542-
#if __AVX512F__
1542+
#if defined(__ARM_NEON)
1543+
for (int i = 0; i < nb; i++) {
1544+
float32x4_t srcv [8];
1545+
float32x4_t asrcv[8];
1546+
float32x4_t amaxv[8];
1547+
1548+
for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l);
1549+
for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]);
1550+
1551+
for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]);
1552+
for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
1553+
for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
1554+
1555+
const float amax = vmaxvq_f32(amaxv[0]);
1556+
1557+
const float d = amax / ((1 << 7) - 1);
1558+
const float id = d ? 1.0f/d : 0.0f;
1559+
1560+
ds[i] = d;
1561+
1562+
for (int l = 0; l < 8; l++) {
1563+
const float32x4_t v = vmulq_n_f32(srcv[l], id);
1564+
const int32x4_t vi = vcvtnq_s32_f32(v);
1565+
1566+
qs[i*QK8_0C + 4*l + 0] = vgetq_lane_s32(vi, 0);
1567+
qs[i*QK8_0C + 4*l + 1] = vgetq_lane_s32(vi, 1);
1568+
qs[i*QK8_0C + 4*l + 2] = vgetq_lane_s32(vi, 2);
1569+
qs[i*QK8_0C + 4*l + 3] = vgetq_lane_s32(vi, 3);
1570+
}
1571+
}
1572+
#elif defined(__AVX512F__)
15431573
for (int i = 0; i < nb; i++) {
15441574
const __m512 x0 = _mm512_loadu_ps( x + i*QK8_0C );
15451575
const __m512 x1 = _mm512_loadu_ps( x + i*QK8_0C + QK8_0C/2);
@@ -2817,7 +2847,69 @@ static void ggml_vec_dot_q4_0c_q8_0c(const int n, float * restrict s, const void
28172847

28182848
float sumf = 0.0;
28192849

2820-
#if __AVX512F__
2850+
#if defined(__ARM_NEON)
2851+
float32x4_t sumv0 = vdupq_n_f32(0.0f);
2852+
float32x4_t sumv1 = vdupq_n_f32(0.0f);
2853+
2854+
for (int i = 0; i < nb/2; i++) {
2855+
const int dst0 = i + i/2*2; // 0, 1, 4, 5, 8, 9, ...
2856+
const int dst1 = i + i/2*2 + 2; // 2, 3, 6, 7, 10, 11 ...
2857+
2858+
const uint8x16_t m4b = vdupq_n_u8(0xf);
2859+
const int8x16_t s8b = vdupq_n_s8(0x8);
2860+
2861+
const uint8x16_t v0_01l = vld1q_u8(&xqs[i*QK4_0]);
2862+
const uint8x16_t v0_01h = vld1q_u8(&xqs[i*QK4_0 + QK4_0/2]);
2863+
2864+
// 4-bit -> 8-bit
2865+
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_01l, m4b));
2866+
const int8x16_t v0_0h = vreinterpretq_s8_u8(vandq_u8 (v0_01h, m4b));
2867+
const int8x16_t v0_1l = vreinterpretq_s8_u8(vshrq_n_u8(v0_01l, 4));
2868+
const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_01h, 4));
2869+
2870+
// sub 8
2871+
const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
2872+
const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
2873+
const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
2874+
const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
2875+
2876+
// load y
2877+
const int8x16_t v1_0l = vld1q_s8(&yqs[dst0*QK8_0C]);
2878+
const int8x16_t v1_0h = vld1q_s8(&yqs[dst0*QK8_0C + 16]);
2879+
const int8x16_t v1_1l = vld1q_s8(&yqs[dst1*QK8_0C]);
2880+
const int8x16_t v1_1h = vld1q_s8(&yqs[dst1*QK8_0C + 16]);
2881+
2882+
#if defined(__ARM_FEATURE_DOTPROD)
2883+
// dot product into int32x4_t
2884+
const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h);
2885+
const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h);
2886+
2887+
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), xds[dst0]*yds[dst0]);
2888+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), xds[dst1]*yds[dst1]);
2889+
#else
2890+
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0l));
2891+
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l));
2892+
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0h));
2893+
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));
2894+
2895+
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1l));
2896+
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l));
2897+
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1h));
2898+
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h));
2899+
2900+
const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
2901+
const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
2902+
const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
2903+
const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
2904+
2905+
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), xds[dst0]*yds[dst0]);
2906+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), xds[dst1]*yds[dst1]);
2907+
#endif
2908+
}
2909+
2910+
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2911+
2912+
#elif defined(__AVX512F__)
28212913
// Initialize accumulator with zeros
28222914
__m512 acc = _mm512_setzero_ps();
28232915
for (int i = 0; i < nb; i += 4) {

0 commit comments

Comments
 (0)