Skip to content

Commit 49cbbc9

Browse files
ikawrakowIwan Kawrakow
andauthored
iq2_tn: slightly better performance on AVX2 (ggml-org#47)
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
1 parent e486643 commit 49cbbc9

1 file changed

Lines changed: 48 additions & 27 deletions

File tree

ggml/src/iqk/iqk_mul_mat.cpp

Lines changed: 48 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1763,42 +1763,63 @@ IQK_NOINLINE void mul_mat_iq2tn_q8_K(int n, const void * vx, size_t bx, const Da
17631763
const int nb = n/QK_K;
17641764

17651765
Q8<nrc_y> q8(info);
1766-
DequantizerIQ2TN deq(vx, bx);
1766+
DequantizerIQ2TN deq1(vx, bx), deq2(vx, bx);
17671767

17681768
__m256 accd[nrc_y];
17691769
const auto m1 = _mm256_set1_epi16(1);
17701770

17711771
for (int ix = 0; ix < nrc_x; ++ix) {
17721772

1773-
deq.new_row(ix);
1773+
deq1.new_row(ix);
1774+
deq2.new_row(ix);
17741775

17751776
for (int i = 0; i < nb; ++i) {
17761777

1777-
__m256i sumi[nrc_y];
1778-
deq.new_block(i);
1778+
deq1.new_block(i);
17791779

1780-
deq.prepare(i, 0);
1781-
for (int iy = 0; iy < nrc_y; ++iy) {
1782-
sumi[iy] = _mm256_add_epi16(_mm256_maddubs_epi16(deq.bits.values[0], q8.load_quants(iy, i, 0)),
1783-
_mm256_maddubs_epi16(deq.bits.values[1], q8.load_quants(iy, i, 1)));
1784-
sumi[iy] = _mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq.bits.values[2], q8.load_quants(iy, i, 2)),
1785-
_mm256_maddubs_epi16(deq.bits.values[3], q8.load_quants(iy, i, 3))), sumi[iy]);
1780+
if constexpr (nrc_y == 1) {
1781+
deq1.prepare(i, 0);
1782+
auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(deq1.bits.values[0], q8.load_quants(0, i, 0)),
1783+
_mm256_maddubs_epi16(deq1.bits.values[1], q8.load_quants(0, i, 1)));
1784+
sumi1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq1.bits.values[2], q8.load_quants(0, i, 2)),
1785+
_mm256_maddubs_epi16(deq1.bits.values[3], q8.load_quants(0, i, 3))), sumi1);
1786+
1787+
deq2.prepare(i, 1);
1788+
auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(deq2.bits.values[0], q8.load_quants(0, i, 4)),
1789+
_mm256_maddubs_epi16(deq2.bits.values[1], q8.load_quants(0, i, 5)));
1790+
sumi2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq2.bits.values[2], q8.load_quants(0, i, 6)),
1791+
_mm256_maddubs_epi16(deq2.bits.values[3], q8.load_quants(0, i, 7))), sumi2);
1792+
auto sumi = _mm256_add_epi16(sumi2, _mm256_sub_epi16(sumi1, q8.load_bsums(0, i)));
1793+
auto vd = _mm256_set1_ps(deq1.d*q8.scale(0, i));
1794+
auto sf = _mm256_cvtepi32_ps(_mm256_madd_epi16(m1, sumi));
1795+
accd[0] = i > 0 ? _mm256_fmadd_ps(vd, sf, accd[0]) : _mm256_mul_ps(vd, sf);
17861796
}
1787-
deq.prepare(i, 1);
1788-
for (int iy = 0; iy < nrc_y; ++iy) {
1789-
sumi[iy] = _mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq.bits.values[0], q8.load_quants(iy, i, 4)),
1790-
_mm256_maddubs_epi16(deq.bits.values[1], q8.load_quants(iy, i, 5))), sumi[iy]);
1791-
sumi[iy] = _mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq.bits.values[2], q8.load_quants(iy, i, 6)),
1792-
_mm256_maddubs_epi16(deq.bits.values[3], q8.load_quants(iy, i, 7))), sumi[iy]);
1793-
sumi[iy] = _mm256_sub_epi16(sumi[iy], q8.load_bsums(iy, i));
1794-
}
1795-
if (i > 0) {
1796-
for (int iy = 0; iy < nrc_y; ++iy) {
1797-
accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(deq.d*q8.scale(iy, i)), _mm256_cvtepi32_ps(_mm256_madd_epi16(m1, sumi[iy])), accd[iy]);
1798-
}
1799-
} else {
1797+
else {
1798+
1799+
deq1.prepare(i, 0); deq2.prepare(i, 1);
18001800
for (int iy = 0; iy < nrc_y; ++iy) {
1801-
accd[iy] = _mm256_mul_ps(_mm256_set1_ps(deq.d*q8.scale(iy, i)), _mm256_cvtepi32_ps(_mm256_madd_epi16(m1, sumi[iy])));
1801+
auto vd = _mm256_set1_ps(deq1.d*q8.scale(iy, i));
1802+
auto sumi = _mm256_add_epi16(_mm256_maddubs_epi16(deq1.bits.values[0], q8.load_quants(iy, i, 0)),
1803+
_mm256_maddubs_epi16(deq1.bits.values[1], q8.load_quants(iy, i, 1)));
1804+
sumi = _mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq1.bits.values[2], q8.load_quants(iy, i, 2)),
1805+
_mm256_maddubs_epi16(deq1.bits.values[3], q8.load_quants(iy, i, 3))), sumi);
1806+
sumi = _mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq2.bits.values[0], q8.load_quants(iy, i, 4)),
1807+
_mm256_maddubs_epi16(deq2.bits.values[1], q8.load_quants(iy, i, 5))), sumi);
1808+
sumi = _mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq2.bits.values[2], q8.load_quants(iy, i, 6)),
1809+
_mm256_maddubs_epi16(deq2.bits.values[3], q8.load_quants(iy, i, 7))), sumi);
1810+
sumi = _mm256_sub_epi16(sumi, q8.load_bsums(iy, i));
1811+
1812+
//auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(deq1.bits.values[0], q8.load_quants(iy, i, 0)),
1813+
// _mm256_maddubs_epi16(deq1.bits.values[1], q8.load_quants(iy, i, 1)));
1814+
//auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(deq1.bits.values[2], q8.load_quants(iy, i, 2)),
1815+
// _mm256_maddubs_epi16(deq1.bits.values[3], q8.load_quants(iy, i, 3)));
1816+
//sumi1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq2.bits.values[0], q8.load_quants(iy, i, 4)),
1817+
// _mm256_maddubs_epi16(deq2.bits.values[1], q8.load_quants(iy, i, 5))), sumi1);
1818+
//sumi2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq2.bits.values[2], q8.load_quants(iy, i, 6)),
1819+
// _mm256_maddubs_epi16(deq2.bits.values[3], q8.load_quants(iy, i, 7))), sumi2);
1820+
//auto sumi = _mm256_add_epi16(sumi2, _mm256_sub_epi16(sumi1, q8.load_bsums(iy, i)));
1821+
auto sf = _mm256_cvtepi32_ps(_mm256_madd_epi16(m1, sumi));
1822+
accd[iy] = i > 0 ? _mm256_fmadd_ps(vd, sf, accd[iy]) : _mm256_mul_ps(vd, sf);
18021823
}
18031824
}
18041825

@@ -3671,9 +3692,9 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
36713692
mm.funcs[2] = mul_mat_iq2tn_q8_K<3>;
36723693
mm.funcs[3] = mul_mat_iq2tn_q8_K<4>;
36733694
mm.funcs[4] = mul_mat_iq2tn_q8_K<5>;
3674-
//mm.funcs[5] = mul_mat_iq2tn_q8_K<6>;
3675-
//mm.funcs[6] = mul_mat_iq2tn_q8_K<7>;
3676-
//mm.funcs[7] = mul_mat_iq2tn_q8_K<8>;
3695+
mm.funcs[5] = mul_mat_iq2tn_q8_K<6>;
3696+
mm.funcs[6] = mul_mat_iq2tn_q8_K<7>;
3697+
mm.funcs[7] = mul_mat_iq2tn_q8_K<8>;
36773698
#endif
36783699
break;
36793700
case GGML_TYPE_Q3_K:

0 commit comments

Comments
 (0)