@@ -606,12 +606,12 @@ typedef struct {
606
606
static_assert (sizeof (block_q2_0 ) == sizeof (ggml_fp16_t ) + QK2_0 / 4 , "wrong q2_0 size/padding" );
607
607
608
608
#define QK3_0 16
609
- typedef union {
610
- struct {
611
- uint16_t pad [ 3 ];
612
- ggml_fp16_t d ;
613
- };
614
- uint64_t qs ;
609
+ typedef struct {
610
+ ggml_fp16_t d ;
611
+ // Instead of representing q3_0 as a packed format "...210210210210",
612
+ // represent it as two planes: "...10101010" and "...2222"
613
+ uint16_t qhi ; // The highest bit of each 3-bit number, packed together
614
+ uint32_t qlo ; // The low 2-bits of each 3-bit number, packed together
615
615
} block_q3_0 ;
616
616
static_assert (sizeof (block_q3_0 ) == sizeof (ggml_fp16_t ) + QK3_0 * 3 / 8 , "wrong q3_0 size/padding" );
617
617
@@ -691,17 +691,20 @@ static void quantize_row_q3_0(const float * restrict x, block_q3_0 * restrict y,
691
691
const float d = max / -4 ;
692
692
const float id = d ? 1.0f /d : 0.0f ;
693
693
694
- uint64_t qs = 0 ;
694
+ uint32_t lo = 0 ;
695
+ uint16_t hi = 0 ;
695
696
696
697
for (int l = 0 ; l < QK3_0 ; l ++ ) {
697
698
const float v = x [i * QK3_0 + l ]* id ;
698
699
const uint8_t vi = MIN (7 , (int8_t )roundf (v ) + 4 );
699
700
assert (vi < 8 );
700
- qs |= (uint64_t )vi << (l * 3 );
701
+ lo |= (vi & 3 ) << (l * 2 );
702
+ hi |= ((vi >> 2 ) & 1 ) << l ;
701
703
}
702
704
703
- y [i ].qs = qs ;
704
- y [i ].d = GGML_FP32_TO_FP16 (d ); // overwrite unused part of uint64_t qs
705
+ y [i ].d = GGML_FP32_TO_FP16 (d );
706
+ y [i ].qlo = lo ;
707
+ y [i ].qhi = hi ;
705
708
}
706
709
}
707
710
@@ -1335,13 +1338,15 @@ static void dequantize_row_q3_0(const void * restrict vx, float * restrict y, in
1335
1338
1336
1339
for (int i = 0 ; i < nb ; i ++ ) {
1337
1340
const float d = GGML_FP16_TO_FP32 (x [i ].d );
1338
- uint64_t qs = x [i ].qs ;
1341
+ uint_fast32_t lo = x [i ].qlo ;
1342
+ uint_fast32_t hi = x [i ].qhi << 2 ;
1339
1343
for (int l = 0 ; l < QK3_0 ; l ++ ) {
1340
- const int8_t vi = qs & 7 ;
1344
+ const int8_t vi = ( lo & 3 ) | ( hi & 4 ) ;
1341
1345
const float v = (vi - 4 )* d ;
1342
1346
y [i * QK3_0 + l ] = v ;
1343
1347
assert (!isnan (y [i * QK3_0 + l ]));
1344
- qs >>= 3 ;
1348
+ lo >>= 2 ;
1349
+ hi >>= 1 ;
1345
1350
}
1346
1351
}
1347
1352
}
@@ -2193,6 +2198,39 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
2193
2198
* s = sumf ;
2194
2199
}
2195
2200
2201
#if defined(__AVX2__) || defined(__AVX512F__)
// Compute the dot product of the signed 8-bit integers packed into two 256-bit
// vectors, summing each group of four adjacent byte products into a 32-bit lane
// and converting the eight lane sums to packed floats.
static inline __m256 dotMul(__m256i bx, __m256i by) {
#if defined(__AVXVNNIINT8__)
    // Multiply signed bytes and accumulate adjacent products straight into 32-bit lanes.
    const __m256i i32 = _mm256_dpbssd_epi32(bx, by, _mm256_setzero_si256());
#else
    // maddubs needs an unsigned * signed operand pair: take |bx| and move
    // bx's sign onto by, which leaves each byte product unchanged.
    const __m256i ax = _mm256_sign_epi8(bx, bx);
    const __m256i sy = _mm256_sign_epi8(by, bx);
    // Multiply bytes and horizontally add adjacent products into 16-bit lanes.
    const __m256i dot = _mm256_maddubs_epi16(ax, sy);

    // Pairwise-add the 16-bit lanes into 32-bit lanes.
    const __m256i ones = _mm256_set1_epi16(1);
    const __m256i i32 = _mm256_madd_epi16(ones, dot);
#endif
    // Convert the 32-bit integer sums to floats.
    return _mm256_cvtepi32_ps(i32);
}

// Return the horizontal sum of the eight packed floats in a 256-bit vector.
static inline float horizontalSum(__m256 acc) {
    __m128 res = _mm256_extractf128_ps(acc, 1);
    res = _mm_add_ps(res, _mm256_castps256_ps128(acc));
    res = _mm_add_ps(res, _mm_movehl_ps(res, res));
    res = _mm_add_ss(res, _mm_movehdup_ps(res));
    return _mm_cvtss_f32(res);
}
#endif
2233
+
2196
2234
static void ggml_vec_dot_q2_0_q8_0 (const int n , float * restrict s , const void * restrict vx , const void * restrict vy ) {
2197
2235
assert (n % QK2_0 == 0 );
2198
2236
const int nb = n / QK2_0 ;
@@ -2222,30 +2260,15 @@ static void ggml_vec_dot_q2_0_q8_0(const int n, float * restrict s, const void *
2222
2260
// Load y vector
2223
2261
const __m256i by = _mm256_loadu_si256 ((const __m256i * )y [i ].qs );
2224
2262
2225
- // Get absolute values of x vectors
2226
- const __m256i ax = _mm256_sign_epi8 (bx , bx );
2227
- // Sign the values of the y vectors
2228
- const __m256i sy = _mm256_sign_epi8 (by , bx );
2229
- // Perform multiplication and create 16-bit values
2230
- const __m256i dot = _mm256_maddubs_epi16 (ax , sy );
2231
-
2232
- // Convert int16_t to int32_t by adding pairwise
2233
- const __m256i ones = _mm256_set1_epi16 (1 );
2234
- __m256i i32 = _mm256_madd_epi16 (ones , dot );
2235
-
2236
- // Convert int32_t to float
2237
- __m256 p = _mm256_cvtepi32_ps (i32 );
2263
+ // Do the product:
2264
+ __m256 p = dotMul (bx , by );
2238
2265
2239
2266
// Apply the scale, and accumulate
2240
2267
acc = _mm256_fmadd_ps (scale , p , acc );
2241
2268
}
2242
2269
2243
2270
// Return horizontal sum of the acc vector
2244
- __m128 res = _mm256_extractf128_ps (acc , 1 );
2245
- res = _mm_add_ps (res , _mm256_castps256_ps128 (acc ));
2246
- res = _mm_add_ps (res , _mm_movehl_ps (res , res ));
2247
- res = _mm_add_ss (res , _mm_movehdup_ps (res ));
2248
- sumf = _mm_cvtss_f32 (res );
2271
+ sumf = horizontalSum (acc );
2249
2272
#else
2250
2273
for (int i = 0 ; i < nb ; i ++ ) {
2251
2274
const float d0 = GGML_FP16_TO_FP32 (x [i ].d );
@@ -2270,6 +2293,20 @@ static void ggml_vec_dot_q2_0_q8_0(const int n, float * restrict s, const void *
2270
2293
* s = sumf ;
2271
2294
}
2272
2295
2296
// Lookup table used to convert the q3_0 high-bit plane to SIMD vectors.
// Expands the 8 bits of an index into a 64-bit result, one byte per bit:
// a zero bit becomes 0xFC (-4, the q3_0 offset) and a one bit becomes 0x00,
// so OR-ing an entry with the low 2-bit plane bytes yields the signed quant
// value directly (high bit set: +4 - 4 = 0x00; high bit clear: -4 = 0xFC).
// Bit 0 of the index maps to the least significant byte of the entry.
#define B0(n) 0x##n
#define B1(n) B0(n##FC), B0(n##00)
#define B2(n) B1(n##FC), B1(n##00)
#define B3(n) B2(n##FC), B2(n##00)
#define B4(n) B3(n##FC), B3(n##00)
#define B5(n) B4(n##FC), B4(n##00)
#define B6(n) B5(n##FC), B5(n##00)
#define B7(n) B6(n##FC), B6(n##00)
#define B8()  B7(FC), B7(00)
static const uint64_t ggml_q3_table[256] = { B8() };
// The helper macros exist only to build the table; don't leak such short
// names into the rest of this (very large) translation unit.
#undef B0
#undef B1
#undef B2
#undef B3
#undef B4
#undef B5
#undef B6
#undef B7
#undef B8
+
2273
2310
static void ggml_vec_dot_q3_0_q8_0 (const int n , float * restrict s , const void * restrict vx , const void * restrict vy ) {
2274
2311
assert (n % QK3_0 == 0 );
2275
2312
const int nb = n / QK3_0 ;
@@ -2282,103 +2319,54 @@ static void ggml_vec_dot_q3_0_q8_0(const int n, float * restrict s, const void *
2282
2319
2283
2320
#if defined(__AVX2__ )
2284
2321
// Initialize accumulator with zeros
2285
- __m128 acc = _mm_setzero_ps ();
2322
+ __m256 acc = _mm256_setzero_ps ();
2323
+
2286
2324
for (int i = 0 ; i < nb /2 ; i ++ ) {
2287
- const __m128 scale_y = _mm_set1_ps (y [i ].d );
2288
- for (int u = 0 ; u < 2 ; u ++ ) { // let the compiler unroll this
2289
- // Compute combined scale for the block
2290
- const __m128 scale_x = _mm_set1_ps (GGML_FP16_TO_FP32 (x [i * 2 + u ].d ));
2291
- const __m128 scale = _mm_mul_ps (scale_x , scale_y );
2292
-
2293
- __m256i bxx = _mm256_set1_epi64x (x [i * 2 + u ].qs );
2294
-
2295
- // legend: _=zero +=one .=don't care 0-f=3bit quantized values s=fp16 scale
2296
-
2297
- // shift the copies to be able to reach all values
2298
- // 255 192 128 64 0
2299
- // | | | |
2300
- // sssssfedcba9876543210sssssfedcba9876543210sssssfedcba9876543210sssssfedcba9876543210 in
2301
- // sssfedcba9876543210_______________________sfedcba9876543210____sssssfedcba9876543210 shift left
2302
- // _______________________sssssfedcba98765432__________________________________________ shift right
2303
- // sssfedcba9876543210____sssssfedcba98765432sfedcba9876543210____sssssfedcba9876543210 out
2304
- // ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^
2305
- // e b 6 3 _ . f a 7 2 c 9 4 1 _ . d 8 5 0
2306
- const __m256i shift_l = _mm256_set_epi64x (2 * 3 , 64 , 4 * 3 , 0 );
2307
- const __m256i shift_r = _mm256_set_epi64x ( 64 , 2 * 3 , 64 , 64 );
2308
- bxx = _mm256_or_si256 (_mm256_sllv_epi64 (bxx , shift_l ), _mm256_srlv_epi64 (bxx , shift_r ));
2309
-
2310
- // add to itself in masked places to shift some values left one bit
2311
- // 127 64 0
2312
- // | | | | | | | | | | | | | | | |
2313
- // ssssfffeeedddcccbbbaaa999888777666555444333222111000____________ssssssssssssssssfffeeedddcccbbbaaa999888777666555444333222111000 in
2314
- // _____________________++++____________________++++____________________________________++++____________________++++_______________ mask
2315
- // _____________________.999____________________.111____________________________________.ddd____________________.555_______________ masked
2316
- // .............ccc.....999.............444.....111....____________.....................ddd.............888.....555.............000 sum
2317
- //
2318
- // 255 192 128
2319
- // | | | | | | | | | | | | | | | |
2320
- // ssssssssssfffeeedddcccbbbaaa999888777666555444333222111000____________ssssssssssssssssfffeeedddcccbbbaaa999888777666555444333222 in
2321
- // _____________________++++____________________++++____________________________________++++____________________++++_______________ mask
2322
- // _____________________.bbb____________________.333____________________________________.fff____________________.777_______________ masked
2323
- // .............eee.....bbb.............666.....333..........____________...............fff.............aaa.....777.............222 sum
2324
- const __m256i doublemask = _mm256_set1_epi64x (0x078000078000 );
2325
- bxx = _mm256_add_epi64 (bxx , _mm256_and_si256 (doublemask , bxx ));
2326
-
2327
- // collect 16 bytes from 256 into 128 bits
2328
- const __m256i shufmask = _mm256_set_epi8 (
2329
- 5 ,14 ,-1 ,-1 ,13 , 3 ,-1 ,-1 , 2 ,11 ,-1 ,-1 ,10 , 0 ,-1 ,-1 ,
2330
- -1 ,-1 , 5 ,14 ,-1 ,-1 ,13 , 3 ,-1 ,-1 , 2 ,11 ,-1 ,-1 ,10 , 0 );
2331
- bxx = _mm256_shuffle_epi8 (bxx , shufmask );
2332
-
2333
- __m128i bx = _mm_or_si128 (_mm256_castsi256_si128 (bxx ), _mm256_extracti128_si256 (bxx , 1 ));
2334
-
2335
- const __m128i mask = _mm_set1_epi8 (7 );
2336
- bx = _mm_and_si128 (mask , bx );
2337
-
2338
- const __m128i off = _mm_set1_epi8 (4 );
2339
- bx = _mm_sub_epi8 (bx , off );
2340
-
2341
- const __m128i by = _mm_loadu_si128 ((const __m128i * )(y [i ].qs + u * QK3_0 ));
2325
+ __m256i bx = bytesFromCrumbs (x [i * 2 + 1 ].qlo , x [i * 2 ].qlo );
2342
2326
2343
- // Get absolute values of x vectors
2344
- const __m128i ax = _mm_sign_epi8 (bx , bx );
2345
- // Sign the values of the y vectors
2346
- const __m128i sy = _mm_sign_epi8 (by , bx );
2347
- // Perform multiplication and create 16-bit values
2348
- const __m128i dot = _mm_maddubs_epi16 (ax , sy );
2327
+ __m256i const bxhi = _mm256_set_epi64x (
2328
+ ggml_q3_table [x [i * 2 + 1 ].qhi >> 8 ], ggml_q3_table [x [i * 2 + 1 ].qhi & 0xFF ],
2329
+ ggml_q3_table [x [i * 2 + 0 ].qhi >> 8 ], ggml_q3_table [x [i * 2 + 0 ].qhi & 0xFF ]);
2349
2330
2350
- // Convert int16_t to int32_t by adding pairwise
2351
- const __m128i ones = _mm_set1_epi16 (1 );
2352
- __m128i i32 = _mm_madd_epi16 (dot , ones );
2331
+ // OR the high bits (which also handles the sign):
2332
+ bx = _mm256_or_si256 (bx , bxhi );
2353
2333
2354
- // Convert int32_t to float
2355
- const __m128 p = _mm_cvtepi32_ps (i32 );
2334
+ // Compute combined scale for the block
2335
+ const __m128 scale_lo = _mm_set1_ps (GGML_FP16_TO_FP32 (x [i * 2 + 0 ].d ));
2336
+ const __m128 scale_hi = _mm_set1_ps (GGML_FP16_TO_FP32 (x [i * 2 + 1 ].d ));
2337
+ __m256 scale = _mm256_set_m128 (scale_hi , scale_lo );
2338
+ scale = _mm256_mul_ps (scale , _mm256_broadcast_ss (& y [i ].d ));
2356
2339
2357
- // Apply the scale, and accumulate
2358
- acc = _mm_fmadd_ps (scale , p , acc );
2359
- }
2340
+ // Load y vector
2341
+ const __m256i by = _mm256_loadu_si256 ((const __m256i * )y [i ].qs );
2342
+
2343
+ // Do the product,
2344
+ __m256 p = dotMul (bx , by );
2345
+
2346
+ // Apply the scale, and accumulate
2347
+ acc = _mm256_fmadd_ps (scale , p , acc );
2360
2348
}
2361
2349
2362
2350
// Return horizontal sum of the acc vector
2363
- __m128 res = _mm_add_ps (acc , _mm_movehl_ps (acc , acc ));
2364
- res = _mm_add_ss (res , _mm_movehdup_ps (res ));
2365
- sumf = _mm_cvtss_f32 (res );
2351
+ sumf = horizontalSum (acc );
2366
2352
#else
2367
2353
for (int i = 0 ; i < nb ; i ++ ) {
2368
2354
const float d0 = GGML_FP16_TO_FP32 (x [i ].d );
2369
2355
const float d1 = y [i /2 ].d ;
2370
2356
2371
- uint64_t qs0 = x [i ].qs ;
2357
+ uint_fast32_t lo0 = x [i ].qlo ;
2358
+ uint_fast32_t hi0 = x [i ].qhi << 2 ;
2372
2359
const int8_t * restrict p1 = y [i /2 ].qs + (i %2 )* QK3_0 ;
2373
2360
2374
2361
int sumi = 0 ;
2375
- for (int j = 0 ; j < QK3_0 ; j ++ ) {
2376
- const int8_t i0 = (int8_t )(qs0 & 7 ) - 4 ;
2377
- const int_fast16_t i1 = p1 [j ];
2362
+ for (int l = 0 ; l < QK3_0 ; l ++ ) {
2363
+ const int8_t i0 = (int8_t )(( lo0 & 3 ) | (( hi0 & 4 ) - 4 )) ;
2364
+ const int_fast16_t i1 = p1 [l ];
2378
2365
2379
2366
sumi += i0 * i1 ;
2380
2367
2381
- qs0 >>= 3 ;
2368
+ lo0 >>= 2 ;
2369
+ hi0 >>= 1 ;
2382
2370
}
2383
2371
sumf += d0 * d1 * sumi ;
2384
2372
}
@@ -11630,11 +11618,13 @@ size_t ggml_quantize_q3_0(const float * src, void * dst, int n, int k, int64_t h
11630
11618
quantize_row_q3_0 (src + j , y , k );
11631
11619
11632
11620
for (int i = 0 ; i < nb ; i ++ ) {
11633
- uint64_t qs = y [i ].qs ;
11621
+ uint_fast32_t lo = y [i ].qlo ;
11622
+ uint_fast32_t hi = y [i ].qhi << 2 ;
11634
11623
for (int l = 0 ; l < QK3_0 ; l ++ ) {
11635
- const int8_t vi = qs & 7 ;
11624
+ int8_t vi = ( lo & 3 ) | ( hi & 4 ) ;
11636
11625
hist [vi ]++ ;
11637
- qs >>= 3 ;
11626
+ lo >>= 2 ;
11627
+ hi >>= 1 ;
11638
11628
}
11639
11629
}
11640
11630
}
0 commit comments