@@ -1763,42 +1763,63 @@ IQK_NOINLINE void mul_mat_iq2tn_q8_K(int n, const void * vx, size_t bx, const Da
17631763 const int nb = n/QK_K ;
17641764
17651765 Q8 <nrc_y> q8 (info);
1766- DequantizerIQ2TN deq (vx, bx);
1766+ DequantizerIQ2TN deq1 (vx, bx), deq2 (vx, bx);
17671767
17681768 __m256 accd[nrc_y];
17691769 const auto m1 = _mm256_set1_epi16 (1 );
17701770
17711771 for (int ix = 0 ; ix < nrc_x; ++ix) {
17721772
1773- deq.new_row (ix);
1773+ deq1.new_row (ix);
1774+ deq2.new_row (ix);
17741775
17751776 for (int i = 0 ; i < nb; ++i) {
17761777
1777- __m256i sumi[nrc_y];
1778- deq.new_block (i);
1778+ deq1.new_block (i);
17791779
1780- deq.prepare (i, 0 );
1781- for (int iy = 0 ; iy < nrc_y; ++iy) {
1782- sumi[iy] = _mm256_add_epi16 (_mm256_maddubs_epi16 (deq.bits .values [0 ], q8.load_quants (iy, i, 0 )),
1783- _mm256_maddubs_epi16 (deq.bits .values [1 ], q8.load_quants (iy, i, 1 )));
1784- sumi[iy] = _mm256_add_epi16 (_mm256_add_epi16 (_mm256_maddubs_epi16 (deq.bits .values [2 ], q8.load_quants (iy, i, 2 )),
1785- _mm256_maddubs_epi16 (deq.bits .values [3 ], q8.load_quants (iy, i, 3 ))), sumi[iy]);
1780+ if constexpr (nrc_y == 1 ) {
1781+ deq1.prepare (i, 0 );
1782+ auto sumi1 = _mm256_add_epi16 (_mm256_maddubs_epi16 (deq1.bits .values [0 ], q8.load_quants (0 , i, 0 )),
1783+ _mm256_maddubs_epi16 (deq1.bits .values [1 ], q8.load_quants (0 , i, 1 )));
1784+ sumi1 = _mm256_add_epi16 (_mm256_add_epi16 (_mm256_maddubs_epi16 (deq1.bits .values [2 ], q8.load_quants (0 , i, 2 )),
1785+ _mm256_maddubs_epi16 (deq1.bits .values [3 ], q8.load_quants (0 , i, 3 ))), sumi1);
1786+
1787+ deq2.prepare (i, 1 );
1788+ auto sumi2 = _mm256_add_epi16 (_mm256_maddubs_epi16 (deq2.bits .values [0 ], q8.load_quants (0 , i, 4 )),
1789+ _mm256_maddubs_epi16 (deq2.bits .values [1 ], q8.load_quants (0 , i, 5 )));
1790+ sumi2 = _mm256_add_epi16 (_mm256_add_epi16 (_mm256_maddubs_epi16 (deq2.bits .values [2 ], q8.load_quants (0 , i, 6 )),
1791+ _mm256_maddubs_epi16 (deq2.bits .values [3 ], q8.load_quants (0 , i, 7 ))), sumi2);
1792+ auto sumi = _mm256_add_epi16 (sumi2, _mm256_sub_epi16 (sumi1, q8.load_bsums (0 , i)));
1793+ auto vd = _mm256_set1_ps (deq1.d *q8.scale (0 , i));
1794+ auto sf = _mm256_cvtepi32_ps (_mm256_madd_epi16 (m1, sumi));
1795+ accd[0 ] = i > 0 ? _mm256_fmadd_ps (vd, sf, accd[0 ]) : _mm256_mul_ps (vd, sf);
17861796 }
1787- deq.prepare (i, 1 );
1788- for (int iy = 0 ; iy < nrc_y; ++iy) {
1789- sumi[iy] = _mm256_add_epi16 (_mm256_add_epi16 (_mm256_maddubs_epi16 (deq.bits .values [0 ], q8.load_quants (iy, i, 4 )),
1790- _mm256_maddubs_epi16 (deq.bits .values [1 ], q8.load_quants (iy, i, 5 ))), sumi[iy]);
1791- sumi[iy] = _mm256_add_epi16 (_mm256_add_epi16 (_mm256_maddubs_epi16 (deq.bits .values [2 ], q8.load_quants (iy, i, 6 )),
1792- _mm256_maddubs_epi16 (deq.bits .values [3 ], q8.load_quants (iy, i, 7 ))), sumi[iy]);
1793- sumi[iy] = _mm256_sub_epi16 (sumi[iy], q8.load_bsums (iy, i));
1794- }
1795- if (i > 0 ) {
1796- for (int iy = 0 ; iy < nrc_y; ++iy) {
1797- accd[iy] = _mm256_fmadd_ps (_mm256_set1_ps (deq.d *q8.scale (iy, i)), _mm256_cvtepi32_ps (_mm256_madd_epi16 (m1, sumi[iy])), accd[iy]);
1798- }
1799- } else {
1797+ else {
1798+
1799+ deq1.prepare (i, 0 ); deq2.prepare (i, 1 );
18001800 for (int iy = 0 ; iy < nrc_y; ++iy) {
1801- accd[iy] = _mm256_mul_ps (_mm256_set1_ps (deq.d *q8.scale (iy, i)), _mm256_cvtepi32_ps (_mm256_madd_epi16 (m1, sumi[iy])));
1801+ auto vd = _mm256_set1_ps (deq1.d *q8.scale (iy, i));
1802+ auto sumi = _mm256_add_epi16 (_mm256_maddubs_epi16 (deq1.bits .values [0 ], q8.load_quants (iy, i, 0 )),
1803+ _mm256_maddubs_epi16 (deq1.bits .values [1 ], q8.load_quants (iy, i, 1 )));
1804+ sumi = _mm256_add_epi16 (_mm256_add_epi16 (_mm256_maddubs_epi16 (deq1.bits .values [2 ], q8.load_quants (iy, i, 2 )),
1805+ _mm256_maddubs_epi16 (deq1.bits .values [3 ], q8.load_quants (iy, i, 3 ))), sumi);
1806+ sumi = _mm256_add_epi16 (_mm256_add_epi16 (_mm256_maddubs_epi16 (deq2.bits .values [0 ], q8.load_quants (iy, i, 4 )),
1807+ _mm256_maddubs_epi16 (deq2.bits .values [1 ], q8.load_quants (iy, i, 5 ))), sumi);
1808+ sumi = _mm256_add_epi16 (_mm256_add_epi16 (_mm256_maddubs_epi16 (deq2.bits .values [2 ], q8.load_quants (iy, i, 6 )),
1809+ _mm256_maddubs_epi16 (deq2.bits .values [3 ], q8.load_quants (iy, i, 7 ))), sumi);
1810+ sumi = _mm256_sub_epi16 (sumi, q8.load_bsums (iy, i));
1811+
1812+ // auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(deq1.bits.values[0], q8.load_quants(iy, i, 0)),
1813+ // _mm256_maddubs_epi16(deq1.bits.values[1], q8.load_quants(iy, i, 1)));
1814+ // auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(deq1.bits.values[2], q8.load_quants(iy, i, 2)),
1815+ // _mm256_maddubs_epi16(deq1.bits.values[3], q8.load_quants(iy, i, 3)));
1816+ // sumi1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq2.bits.values[0], q8.load_quants(iy, i, 4)),
1817+ // _mm256_maddubs_epi16(deq2.bits.values[1], q8.load_quants(iy, i, 5))), sumi1);
1818+ // sumi2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq2.bits.values[2], q8.load_quants(iy, i, 6)),
1819+ // _mm256_maddubs_epi16(deq2.bits.values[3], q8.load_quants(iy, i, 7))), sumi2);
1820+ // auto sumi = _mm256_add_epi16(sumi2, _mm256_sub_epi16(sumi1, q8.load_bsums(iy, i)));
1821+ auto sf = _mm256_cvtepi32_ps (_mm256_madd_epi16 (m1, sumi));
1822+ accd[iy] = i > 0 ? _mm256_fmadd_ps (vd, sf, accd[iy]) : _mm256_mul_ps (vd, sf);
18021823 }
18031824 }
18041825
@@ -3671,9 +3692,9 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
36713692 mm.funcs [2 ] = mul_mat_iq2tn_q8_K<3 >;
36723693 mm.funcs [3 ] = mul_mat_iq2tn_q8_K<4 >;
36733694 mm.funcs [4 ] = mul_mat_iq2tn_q8_K<5 >;
3674- // mm.funcs[5] = mul_mat_iq2tn_q8_K<6>;
3675- // mm.funcs[6] = mul_mat_iq2tn_q8_K<7>;
3676- // mm.funcs[7] = mul_mat_iq2tn_q8_K<8>;
3695+ mm.funcs [5 ] = mul_mat_iq2tn_q8_K<6 >;
3696+ mm.funcs [6 ] = mul_mat_iq2tn_q8_K<7 >;
3697+ mm.funcs [7 ] = mul_mat_iq2tn_q8_K<8 >;
36773698#endif
36783699 break ;
36793700 case GGML_TYPE_Q3_K :
0 commit comments