Skip to content

Commit e72430f

Browse files
authored
Merge pull request #2803 from xiegengxin/AVX2-asum
Implementaion of dasum, sasum with AVX2 & AVX512 intrinsic
2 parents 6e0f6c5 + 1b0f17e commit e72430f

File tree

7 files changed

+495
-0
lines changed

7 files changed

+495
-0
lines changed

kernel/x86_64/KERNEL.HASWELL

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,5 @@ ZTRSMKERNEL_RT = ../generic/trsm_kernel_RT.c
100100
CGEMM3MKERNEL = cgemm3m_kernel_8x4_haswell.c
101101
ZGEMM3MKERNEL = zgemm3m_kernel_4x4_haswell.c
102102

103+
SASUMKERNEL = sasum.c
104+
DASUMKERNEL = dasum.c

kernel/x86_64/dasum.c

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
#include "common.h"
2+
3+
#ifndef ABS_K
4+
#define ABS_K(a) ((a) > 0 ? (a) : (-(a)))
5+
#endif
6+
7+
#if defined(SKYLAKEX)
8+
#include "dasum_microk_skylakex-2.c"
9+
#elif defined(HASWELL)
10+
#include "dasum_microk_haswell-2.c"
11+
#endif
12+
13+
#ifndef HAVE_DASUM_KERNEL
14+
static FLOAT dasum_kernel(BLASLONG n, FLOAT *x1)
15+
{
16+
17+
BLASLONG i=0;
18+
BLASLONG n_8 = n & -8;
19+
FLOAT *x = x1;
20+
FLOAT temp0, temp1, temp2, temp3;
21+
FLOAT temp4, temp5, temp6, temp7;
22+
FLOAT sum0 = 0.0;
23+
FLOAT sum1 = 0.0;
24+
FLOAT sum2 = 0.0;
25+
FLOAT sum3 = 0.0;
26+
FLOAT sum4 = 0.0;
27+
28+
while (i < n_8) {
29+
temp0 = ABS_K(x[0]);
30+
temp1 = ABS_K(x[1]);
31+
temp2 = ABS_K(x[2]);
32+
temp3 = ABS_K(x[3]);
33+
temp4 = ABS_K(x[4]);
34+
temp5 = ABS_K(x[5]);
35+
temp6 = ABS_K(x[6]);
36+
temp7 = ABS_K(x[7]);
37+
38+
sum0 += temp0;
39+
sum1 += temp1;
40+
sum2 += temp2;
41+
sum3 += temp3;
42+
43+
sum0 += temp4;
44+
sum1 += temp5;
45+
sum2 += temp6;
46+
sum3 += temp7;
47+
48+
x+=8;
49+
i+=8;
50+
}
51+
52+
while (i < n) {
53+
sum4 += ABS_K(x1[i]);
54+
i++;
55+
}
56+
57+
return sum0+sum1+sum2+sum3+sum4;
58+
}
59+
60+
#endif
61+
62+
FLOAT CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x)
63+
{
64+
BLASLONG i=0;
65+
FLOAT sumf = 0.0;
66+
67+
if (n <= 0 || inc_x <= 0) return(sumf);
68+
69+
if ( inc_x == 1 ) {
70+
sumf = dasum_kernel(n, x);
71+
}
72+
else {
73+
n *= inc_x;
74+
75+
while(i < n) {
76+
sumf += ABS_K(x[i]);
77+
i += inc_x;
78+
}
79+
}
80+
return(sumf);
81+
}
82+
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
#if (( defined(__GNUC__) && __GNUC__ > 6 ) || (defined(__clang__) && __clang_major__ >= 6)) && defined(__AVX2__)
2+
3+
#define HAVE_DASUM_KERNEL
4+
5+
#include <immintrin.h>
6+
#include <stdint.h>
7+
8+
#ifndef ABS_K
9+
#define ABS_K(a) ((a) > 0 ? (a) : (-(a)))
10+
#endif
11+
12+
static FLOAT dasum_kernel(BLASLONG n, FLOAT *x1)
13+
{
14+
BLASLONG i = 0;
15+
FLOAT sumf = 0.0;
16+
17+
if (n >= 256) {
18+
BLASLONG align_256 = ((32 - ((uintptr_t)x1 & (uintptr_t)0x1f)) >> 3) & 0x3;
19+
20+
for (i = 0; i < align_256; i++) {
21+
sumf += ABS_K(x1[i]);
22+
}
23+
24+
n -= align_256;
25+
x1 += align_256;
26+
}
27+
28+
BLASLONG tail_index_SSE = n&(~7);
29+
BLASLONG tail_index_AVX2 = n&(~255);
30+
31+
if (n >= 256) {
32+
__m256d accum_0, accum_1, accum_2, accum_3;
33+
34+
accum_0 = _mm256_setzero_pd();
35+
accum_1 = _mm256_setzero_pd();
36+
accum_2 = _mm256_setzero_pd();
37+
accum_3 = _mm256_setzero_pd();
38+
39+
__m256i abs_mask = _mm256_set1_epi64x(0x7fffffffffffffff);
40+
for (i = 0; i < tail_index_AVX2; i += 16) {
41+
accum_0 += (__m256d)_mm256_and_si256(_mm256_load_si256(&x1[i+ 0]), abs_mask);
42+
accum_1 += (__m256d)_mm256_and_si256(_mm256_load_si256(&x1[i+ 4]), abs_mask);
43+
accum_2 += (__m256d)_mm256_and_si256(_mm256_load_si256(&x1[i+ 8]), abs_mask);
44+
accum_3 += (__m256d)_mm256_and_si256(_mm256_load_si256(&x1[i+12]), abs_mask);
45+
}
46+
47+
accum_0 = accum_0 + accum_1 + accum_2 + accum_3;
48+
49+
__m128d half_accum0;
50+
half_accum0 = _mm_add_pd(_mm256_extractf128_pd(accum_0, 0), _mm256_extractf128_pd(accum_0, 1));
51+
52+
half_accum0 = _mm_hadd_pd(half_accum0, half_accum0);
53+
54+
sumf += half_accum0[0];
55+
}
56+
57+
if (n >= 8) {
58+
__m128d accum_20, accum_21, accum_22, accum_23;
59+
accum_20 = _mm_setzero_pd();
60+
accum_21 = _mm_setzero_pd();
61+
accum_22 = _mm_setzero_pd();
62+
accum_23 = _mm_setzero_pd();
63+
64+
__m128i abs_mask2 = _mm_set1_epi64x(0x7fffffffffffffff);
65+
for (i = tail_index_AVX2; i < tail_index_SSE; i += 8) {
66+
accum_20 += (__m128d)_mm_and_si128(_mm_loadu_si128(&x1[i + 0]), abs_mask2);
67+
accum_21 += (__m128d)_mm_and_si128(_mm_loadu_si128(&x1[i + 2]), abs_mask2);
68+
accum_22 += (__m128d)_mm_and_si128(_mm_loadu_si128(&x1[i + 4]), abs_mask2);
69+
accum_23 += (__m128d)_mm_and_si128(_mm_loadu_si128(&x1[i + 6]), abs_mask2);
70+
}
71+
72+
accum_20 = accum_20 + accum_21 + accum_22 + accum_23;
73+
__m128d half_accum20;
74+
half_accum20 = _mm_hadd_pd(accum_20, accum_20);
75+
76+
sumf += half_accum20[0];
77+
}
78+
79+
for (i = tail_index_SSE; i < n; ++i) {
80+
sumf += ABS_K(x1[i]);
81+
}
82+
83+
return sumf;
84+
85+
}
86+
#endif
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
/* need a new enough GCC for avx512 support */
2+
#if (( defined(__GNUC__) && __GNUC__ > 6 && defined(__AVX512CD__)) || (defined(__clang__) && __clang_major__ >= 9))
3+
4+
#define HAVE_DASUM_KERNEL 1
5+
6+
#include <immintrin.h>
7+
8+
#include <stdint.h>
9+
10+
#ifndef ABS_K
11+
#define ABS_K(a) ((a) > 0 ? (a) : (-(a)))
12+
#endif
13+
14+
static FLOAT dasum_kernel(BLASLONG n, FLOAT *x1)
15+
{
16+
BLASLONG i = 0;
17+
FLOAT sumf = 0.0;
18+
19+
if (n >= 256) {
20+
BLASLONG align_512 = ((64 - ((uintptr_t)x1 & (uintptr_t)0x3f)) >> 3) & 0x7;
21+
22+
for (i = 0; i < align_512; i++) {
23+
sumf += ABS_K(x1[i]);
24+
}
25+
26+
n -= align_512;
27+
x1 += align_512;
28+
}
29+
30+
BLASLONG tail_index_SSE = n&(~7);
31+
BLASLONG tail_index_AVX512 = n&(~255);
32+
33+
//
34+
if ( n >= 256 ) {
35+
36+
__m512d accum_0, accum_1, accum_2, accum_3;
37+
accum_0 = _mm512_setzero_pd();
38+
accum_1 = _mm512_setzero_pd();
39+
accum_2 = _mm512_setzero_pd();
40+
accum_3 = _mm512_setzero_pd();
41+
for (i = 0; i < tail_index_AVX512; i += 32) {
42+
accum_0 += _mm512_abs_pd(_mm512_load_pd(&x1[i + 0]));
43+
accum_1 += _mm512_abs_pd(_mm512_load_pd(&x1[i + 8]));
44+
accum_2 += _mm512_abs_pd(_mm512_load_pd(&x1[i +16]));
45+
accum_3 += _mm512_abs_pd(_mm512_load_pd(&x1[i +24]));
46+
}
47+
48+
accum_0 = accum_0 + accum_1 + accum_2 + accum_3;
49+
sumf += _mm512_reduce_add_pd(accum_0);
50+
}
51+
52+
if (n >= 8) {
53+
__m128d accum_20, accum_21, accum_22, accum_23;
54+
accum_20 = _mm_setzero_pd();
55+
accum_21 = _mm_setzero_pd();
56+
accum_22 = _mm_setzero_pd();
57+
accum_23 = _mm_setzero_pd();
58+
59+
__m128i abs_mask2 = _mm_set1_epi64x(0x7fffffffffffffff);
60+
for (i = tail_index_AVX512; i < tail_index_SSE; i += 8) {
61+
accum_20 += (__m128d)_mm_and_si128(_mm_loadu_si128(&x1[i + 0]), abs_mask2);
62+
accum_21 += (__m128d)_mm_and_si128(_mm_loadu_si128(&x1[i + 2]), abs_mask2);
63+
accum_22 += (__m128d)_mm_and_si128(_mm_loadu_si128(&x1[i + 4]), abs_mask2);
64+
accum_23 += (__m128d)_mm_and_si128(_mm_loadu_si128(&x1[i + 6]), abs_mask2);
65+
}
66+
67+
accum_20 = accum_20 + accum_21 + accum_22 + accum_23;
68+
__m128d half_accum20;
69+
half_accum20 = _mm_hadd_pd(accum_20, accum_20);
70+
71+
sumf += half_accum20[0];
72+
}
73+
74+
for (i = tail_index_SSE; i < n; ++i) {
75+
sumf += ABS_K(x1[i]);
76+
}
77+
78+
return sumf;
79+
}
80+
#endif

kernel/x86_64/sasum.c

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
#include "common.h"
2+
3+
#if defined(DOUBLE)
4+
#error supports float only
5+
#else
6+
#ifndef ABS_K
7+
#define ABS_K(a) ((a) > 0 ? (a) : (-(a)))
8+
#endif
9+
10+
#endif
11+
12+
#if defined(SKYLAKEX)
13+
#include "sasum_microk_skylakex-2.c"
14+
#elif defined(HASWELL)
15+
#include "sasum_microk_haswell-2.c"
16+
#endif
17+
18+
#ifndef HAVE_SASUM_KERNEL
19+
20+
static FLOAT sasum_kernel(BLASLONG n, FLOAT *x1)
21+
{
22+
23+
BLASLONG i=0;
24+
BLASLONG n_8 = n & -8;
25+
FLOAT *x = x1;
26+
FLOAT temp0, temp1, temp2, temp3;
27+
FLOAT temp4, temp5, temp6, temp7;
28+
FLOAT sum0 = 0.0;
29+
FLOAT sum1 = 0.0;
30+
FLOAT sum2 = 0.0;
31+
FLOAT sum3 = 0.0;
32+
FLOAT sum4 = 0.0;
33+
34+
while (i < n_8) {
35+
36+
temp0 = ABS_K(x[0]);
37+
temp1 = ABS_K(x[1]);
38+
temp2 = ABS_K(x[2]);
39+
temp3 = ABS_K(x[3]);
40+
temp4 = ABS_K(x[4]);
41+
temp5 = ABS_K(x[5]);
42+
temp6 = ABS_K(x[6]);
43+
temp7 = ABS_K(x[7]);
44+
45+
sum0 += temp0;
46+
sum1 += temp1;
47+
sum2 += temp2;
48+
sum3 += temp3;
49+
50+
sum0 += temp4;
51+
sum1 += temp5;
52+
sum2 += temp6;
53+
sum3 += temp7;
54+
55+
x+=8;
56+
i+=8;
57+
58+
}
59+
60+
while (i < n) {
61+
sum4 += ABS_K(x1[i]);
62+
i++;
63+
}
64+
65+
return sum0+sum1+sum2+sum3+sum4;
66+
}
67+
68+
#endif
69+
70+
FLOAT CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x)
71+
{
72+
BLASLONG i=0;
73+
FLOAT sumf = 0.0;
74+
75+
if (n <= 0 || inc_x <= 0) return(sumf);
76+
77+
if ( inc_x == 1 ) {
78+
sumf = sasum_kernel(n, x);
79+
}
80+
else {
81+
82+
n *= inc_x;
83+
while(i < n) {
84+
sumf += ABS_K(x[i]);
85+
i += inc_x;
86+
}
87+
88+
}
89+
return(sumf);
90+
}

0 commit comments

Comments
 (0)