@@ -2288,31 +2288,17 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
2288
2288
const uint8_t * restrict p0 = x [i ].qs ;
2289
2289
const uint8_t * restrict p1 = y [i ].qs ;
2290
2290
2291
- for (int j = 0 ; j < QK /4 ; j ++ ) {
2292
- const uint32_t v0 = ((uint32_t * )p0 )[j ];
2293
- const uint32_t v1 = ((uint32_t * )p1 )[j ];
2294
-
2295
- const uint8_t v0_0 = (v0 >> 0 ) & 0xf ;
2296
- const uint8_t v0_1 = (v0 >> 4 ) & 0xf ;
2297
- const uint8_t v0_2 = (v0 >> 8 ) & 0xf ;
2298
- const uint8_t v0_3 = (v0 >> 12 ) & 0xf ;
2299
-
2300
- const uint8_t v1_0 = (v1 >> 0 ) & 0xf ;
2301
- const uint8_t v1_1 = (v1 >> 4 ) & 0xf ;
2302
- const uint8_t v1_2 = (v1 >> 8 ) & 0xf ;
2303
- const uint8_t v1_3 = (v1 >> 12 ) & 0xf ;
2304
-
2305
- const float f0 = d0 * v0_0 + m0 ;
2306
- const float f1 = d0 * v0_1 + m0 ;
2307
- const float f2 = d0 * v0_2 + m0 ;
2308
- const float f3 = d0 * v0_3 + m0 ;
2309
-
2310
- const float f4 = d1 * v1_0 + m1 ;
2311
- const float f5 = d1 * v1_1 + m1 ;
2312
- const float f6 = d1 * v1_2 + m1 ;
2313
- const float f7 = d1 * v1_3 + m1 ;
2314
-
2315
- sumf += f0 * f4 + f1 * f5 + f2 * f6 + f3 * f7 ;
2291
+ for (int j = 0 ; j < QK /2 ; j ++ ) {
2292
+ const uint8_t v0 = p0 [j ];
2293
+ const uint8_t v1 = p1 [j ];
2294
+
2295
+ const float f0 = d0 * (v0 & 0xf ) + m0 ;
2296
+ const float f1 = d0 * (v0 >> 4 ) + m0 ;
2297
+
2298
+ const float f2 = d1 * (v1 & 0xf ) + m1 ;
2299
+ const float f3 = d1 * (v1 >> 4 ) + m1 ;
2300
+
2301
+ sumf += f0 * f2 + f1 * f3 ;
2316
2302
}
2317
2303
}
2318
2304
#endif
0 commit comments