Skip to content

Implementation of dasum, sasum with AVX2 & AVX512 intrinsics #2803

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions kernel/x86_64/KERNEL.HASWELL
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,5 @@ ZTRSMKERNEL_RT = ../generic/trsm_kernel_RT.c
CGEMM3MKERNEL = cgemm3m_kernel_8x4_haswell.c
ZGEMM3MKERNEL = zgemm3m_kernel_4x4_haswell.c

# Source files selected for the single- (SASUM) and double-precision (DASUM)
# absolute-sum kernels in this kernel configuration.
SASUMKERNEL = sasum.c
DASUMKERNEL = dasum.c
82 changes: 82 additions & 0 deletions kernel/x86_64/dasum.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
#include "common.h"

/*
 * Fallback absolute-value macro, used only when ABS_K is not already defined.
 * Yields (a) when a > 0 and -(a) otherwise. The argument is expanded more
 * than once, so it must be free of side effects.
 */
#ifndef ABS_K
#define ABS_K(a) ((a) > 0 ? (a) : (-(a)))
#endif

/*
 * Pull in the target-specific micro-kernel for the configured CPU, if any.
 * A micro-kernel file defines HAVE_DASUM_KERNEL when its own compiler/ISA
 * guard is satisfied; otherwise the portable dasum_kernel below is compiled.
 */
#if defined(SKYLAKEX)
#include "dasum_microk_skylakex-2.c"
#elif defined(HASWELL)
#include "dasum_microk_haswell-2.c"
#endif

#ifndef HAVE_DASUM_KERNEL
/*
 * Portable unit-stride kernel: returns the sum of ABS_K(x1[k]) for k in [0, n).
 *
 * Elements are consumed eight at a time and spread over four running sums
 * (lane j receives elements j and j + 4 of each group); the fewer-than-eight
 * leftover elements are accumulated into a separate fifth sum.
 */
static FLOAT dasum_kernel(BLASLONG n, FLOAT *x1)
{
    const BLASLONG unrolled_end = n & -8;   /* largest multiple of 8 not above n */
    FLOAT lane0 = 0.0;
    FLOAT lane1 = 0.0;
    FLOAT lane2 = 0.0;
    FLOAT lane3 = 0.0;
    FLOAT leftover = 0.0;
    BLASLONG k;

    for (k = 0; k < unrolled_end; k += 8) {
        const FLOAT *grp = x1 + k;

        /* first half of the group */
        lane0 += ABS_K(grp[0]);
        lane1 += ABS_K(grp[1]);
        lane2 += ABS_K(grp[2]);
        lane3 += ABS_K(grp[3]);

        /* second half of the group, folded into the same four lanes */
        lane0 += ABS_K(grp[4]);
        lane1 += ABS_K(grp[5]);
        lane2 += ABS_K(grp[6]);
        lane3 += ABS_K(grp[7]);
    }

    /* k now equals the number of elements already handled (0 if none). */
    for (; k < n; k++) {
        leftover += ABS_K(x1[k]);
    }

    return lane0 + lane1 + lane2 + lane3 + leftover;
}

#endif

/*
 * Entry point: sum of absolute values of n elements of x taken with stride inc_x.
 *
 * Returns 0.0 when n <= 0 or inc_x <= 0. Unit stride is delegated to
 * dasum_kernel; any other positive stride is handled by a scalar loop.
 */
FLOAT CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x)
{
    FLOAT total = 0.0;
    BLASLONG pos;
    BLASLONG span;

    if (n <= 0 || inc_x <= 0) {
        return total;
    }

    if (inc_x == 1) {
        return dasum_kernel(n, x);
    }

    /* One past the last index visited when stepping by inc_x. */
    span = n * inc_x;
    for (pos = 0; pos < span; pos += inc_x) {
        total += ABS_K(x[pos]);
    }

    return total;
}

86 changes: 86 additions & 0 deletions kernel/x86_64/dasum_microk_haswell-2.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
#if (( defined(__GNUC__) && __GNUC__ > 6 ) || (defined(__clang__) && __clang_major__ >= 6)) && defined(__AVX2__)

#define HAVE_DASUM_KERNEL

#include <immintrin.h>
#include <stdint.h>

#ifndef ABS_K
#define ABS_K(a) ((a) > 0 ? (a) : (-(a)))
#endif

/*
 * AVX2 unit-stride kernel: returns the sum of |x1[i]| for i in [0, n).
 *
 * Work is split into four stages:
 *   1. when n >= 256, up to three leading elements are summed in scalar code
 *      so that x1 reaches a 32-byte boundary (this holds when x1 is 8-byte
 *      aligned);
 *   2. a 256-bit loop over the largest multiple of 256 elements, 16 elements
 *      per iteration in four independent accumulators;
 *   3. a 128-bit loop over the remaining multiples of 8 elements;
 *   4. a scalar loop over the final (fewer than 8) elements.
 *
 * The vector code operates on packed doubles, so FLOAT is expected to be
 * double in this translation unit.
 *
 * All loads use the typed, unaligned "_pd" intrinsics: this avoids passing a
 * FLOAT pointer where a pointer to __m256i / __m128i is required, and it
 * cannot fault if the caller hands in a buffer that is not 8-byte aligned
 * (in which case stage 1 cannot produce a 32-byte aligned pointer).
 */
static FLOAT dasum_kernel(BLASLONG n, FLOAT *x1)
{
    BLASLONG i = 0;
    FLOAT sumf = 0.0;

    if (n >= 256) {
        /* Count of doubles (0..3) in front of the next 32-byte boundary. */
        BLASLONG align_256 = ((32 - ((uintptr_t)x1 & (uintptr_t)0x1f)) >> 3) & 0x3;

        for (i = 0; i < align_256; i++) {
            sumf += ABS_K(x1[i]);
        }

        n -= align_256;
        x1 += align_256;
    }

    /* n may have dropped below 256 after peeling; the bounds below account for that. */
    BLASLONG tail_index_SSE = n & (~7);
    BLASLONG tail_index_AVX2 = n & (~255);

    if (n >= 256) {
        /* Clearing the sign bit of each lane yields the absolute value. */
        const __m256d abs_mask = _mm256_castsi256_pd(_mm256_set1_epi64x(0x7fffffffffffffffLL));
        __m256d accum_0 = _mm256_setzero_pd();
        __m256d accum_1 = _mm256_setzero_pd();
        __m256d accum_2 = _mm256_setzero_pd();
        __m256d accum_3 = _mm256_setzero_pd();

        for (i = 0; i < tail_index_AVX2; i += 16) {
            accum_0 = _mm256_add_pd(accum_0, _mm256_and_pd(_mm256_loadu_pd(&x1[i +  0]), abs_mask));
            accum_1 = _mm256_add_pd(accum_1, _mm256_and_pd(_mm256_loadu_pd(&x1[i +  4]), abs_mask));
            accum_2 = _mm256_add_pd(accum_2, _mm256_and_pd(_mm256_loadu_pd(&x1[i +  8]), abs_mask));
            accum_3 = _mm256_add_pd(accum_3, _mm256_and_pd(_mm256_loadu_pd(&x1[i + 12]), abs_mask));
        }

        /* Combine the accumulators left to right: ((a0 + a1) + a2) + a3. */
        accum_0 = _mm256_add_pd(_mm256_add_pd(_mm256_add_pd(accum_0, accum_1), accum_2), accum_3);

        /* Horizontal reduction: fold the upper 128 bits onto the lower, then add the pair. */
        __m128d half_accum0 = _mm_add_pd(_mm256_extractf128_pd(accum_0, 0),
                                         _mm256_extractf128_pd(accum_0, 1));
        half_accum0 = _mm_hadd_pd(half_accum0, half_accum0);

        sumf += _mm_cvtsd_f64(half_accum0);
    }

    if (n >= 8) {
        const __m128d abs_mask2 = _mm_castsi128_pd(_mm_set1_epi64x(0x7fffffffffffffffLL));
        __m128d accum_20 = _mm_setzero_pd();
        __m128d accum_21 = _mm_setzero_pd();
        __m128d accum_22 = _mm_setzero_pd();
        __m128d accum_23 = _mm_setzero_pd();

        /* Starts where the 256-bit loop stopped (index 0 if it did not run). */
        for (i = tail_index_AVX2; i < tail_index_SSE; i += 8) {
            accum_20 = _mm_add_pd(accum_20, _mm_and_pd(_mm_loadu_pd(&x1[i + 0]), abs_mask2));
            accum_21 = _mm_add_pd(accum_21, _mm_and_pd(_mm_loadu_pd(&x1[i + 2]), abs_mask2));
            accum_22 = _mm_add_pd(accum_22, _mm_and_pd(_mm_loadu_pd(&x1[i + 4]), abs_mask2));
            accum_23 = _mm_add_pd(accum_23, _mm_and_pd(_mm_loadu_pd(&x1[i + 6]), abs_mask2));
        }

        accum_20 = _mm_add_pd(_mm_add_pd(_mm_add_pd(accum_20, accum_21), accum_22), accum_23);

        __m128d half_accum20 = _mm_hadd_pd(accum_20, accum_20);

        sumf += _mm_cvtsd_f64(half_accum20);
    }

    /* Scalar clean-up for the last n % 8 elements. */
    for (i = tail_index_SSE; i < n; ++i) {
        sumf += ABS_K(x1[i]);
    }

    return sumf;
}
#endif
80 changes: 80 additions & 0 deletions kernel/x86_64/dasum_microk_skylakex-2.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/* need a new enough GCC for avx512 support */
#if (( defined(__GNUC__) && __GNUC__ > 6 && defined(__AVX512CD__)) || (defined(__clang__) && __clang_major__ >= 9))

#define HAVE_DASUM_KERNEL 1

#include <immintrin.h>

#include <stdint.h>

#ifndef ABS_K
#define ABS_K(a) ((a) > 0 ? (a) : (-(a)))
#endif

/*
 * AVX-512 unit-stride kernel: returns the sum of |x1[i]| for i in [0, n).
 *
 * Work is split into four stages:
 *   1. when n >= 256, up to seven leading elements are summed in scalar code
 *      so that x1 reaches a 64-byte boundary (this holds when x1 is 8-byte
 *      aligned);
 *   2. a 512-bit loop over the largest multiple of 256 elements, 32 elements
 *      per iteration in four independent accumulators;
 *   3. a 128-bit loop over the remaining multiples of 8 elements;
 *   4. a scalar loop over the final (fewer than 8) elements.
 *
 * The vector code operates on packed doubles, so FLOAT is expected to be
 * double in this translation unit.
 *
 * All loads use the typed, unaligned "_pd" intrinsics: this avoids passing a
 * FLOAT pointer where a pointer to __m128i is required, and it cannot fault
 * if the caller hands in a buffer that is not 8-byte aligned (in which case
 * stage 1 cannot produce a 64-byte aligned pointer).
 */
static FLOAT dasum_kernel(BLASLONG n, FLOAT *x1)
{
    BLASLONG i = 0;
    FLOAT sumf = 0.0;

    if (n >= 256) {
        /* Count of doubles (0..7) in front of the next 64-byte boundary. */
        BLASLONG align_512 = ((64 - ((uintptr_t)x1 & (uintptr_t)0x3f)) >> 3) & 0x7;

        for (i = 0; i < align_512; i++) {
            sumf += ABS_K(x1[i]);
        }

        n -= align_512;
        x1 += align_512;
    }

    /* n may have dropped below 256 after peeling; the bounds below account for that. */
    BLASLONG tail_index_SSE = n & (~7);
    BLASLONG tail_index_AVX512 = n & (~255);

    if (n >= 256) {
        __m512d accum_0 = _mm512_setzero_pd();
        __m512d accum_1 = _mm512_setzero_pd();
        __m512d accum_2 = _mm512_setzero_pd();
        __m512d accum_3 = _mm512_setzero_pd();

        for (i = 0; i < tail_index_AVX512; i += 32) {
            accum_0 = _mm512_add_pd(accum_0, _mm512_abs_pd(_mm512_loadu_pd(&x1[i +  0])));
            accum_1 = _mm512_add_pd(accum_1, _mm512_abs_pd(_mm512_loadu_pd(&x1[i +  8])));
            accum_2 = _mm512_add_pd(accum_2, _mm512_abs_pd(_mm512_loadu_pd(&x1[i + 16])));
            accum_3 = _mm512_add_pd(accum_3, _mm512_abs_pd(_mm512_loadu_pd(&x1[i + 24])));
        }

        /* Combine the accumulators left to right: ((a0 + a1) + a2) + a3. */
        accum_0 = _mm512_add_pd(_mm512_add_pd(_mm512_add_pd(accum_0, accum_1), accum_2), accum_3);
        sumf += _mm512_reduce_add_pd(accum_0);
    }

    if (n >= 8) {
        /* Clearing the sign bit of each lane yields the absolute value. */
        const __m128d abs_mask2 = _mm_castsi128_pd(_mm_set1_epi64x(0x7fffffffffffffffLL));
        __m128d accum_20 = _mm_setzero_pd();
        __m128d accum_21 = _mm_setzero_pd();
        __m128d accum_22 = _mm_setzero_pd();
        __m128d accum_23 = _mm_setzero_pd();

        /* Starts where the 512-bit loop stopped (index 0 if it did not run). */
        for (i = tail_index_AVX512; i < tail_index_SSE; i += 8) {
            accum_20 = _mm_add_pd(accum_20, _mm_and_pd(_mm_loadu_pd(&x1[i + 0]), abs_mask2));
            accum_21 = _mm_add_pd(accum_21, _mm_and_pd(_mm_loadu_pd(&x1[i + 2]), abs_mask2));
            accum_22 = _mm_add_pd(accum_22, _mm_and_pd(_mm_loadu_pd(&x1[i + 4]), abs_mask2));
            accum_23 = _mm_add_pd(accum_23, _mm_and_pd(_mm_loadu_pd(&x1[i + 6]), abs_mask2));
        }

        accum_20 = _mm_add_pd(_mm_add_pd(_mm_add_pd(accum_20, accum_21), accum_22), accum_23);

        __m128d half_accum20 = _mm_hadd_pd(accum_20, accum_20);

        sumf += _mm_cvtsd_f64(half_accum20);
    }

    /* Scalar clean-up for the last n % 8 elements. */
    for (i = tail_index_SSE; i < n; ++i) {
        sumf += ABS_K(x1[i]);
    }

    return sumf;
}
#endif
90 changes: 90 additions & 0 deletions kernel/x86_64/sasum.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
#include "common.h"

/* This file is single-precision only: refuse to build when DOUBLE is defined. */
#if defined(DOUBLE)
#error supports float only
#else
/*
 * Fallback absolute-value macro, used only when ABS_K is not already defined.
 * Yields (a) when a > 0 and -(a) otherwise. The argument is expanded more
 * than once, so it must be free of side effects.
 */
#ifndef ABS_K
#define ABS_K(a) ((a) > 0 ? (a) : (-(a)))
#endif

#endif

/*
 * Pull in the target-specific micro-kernel for the configured CPU, if any.
 * NOTE(review): the included file presumably defines HAVE_SASUM_KERNEL when
 * it provides sasum_kernel (the check below depends on it) - confirm against
 * the micro-kernel sources.
 */
#if defined(SKYLAKEX)
#include "sasum_microk_skylakex-2.c"
#elif defined(HASWELL)
#include "sasum_microk_haswell-2.c"
#endif

#ifndef HAVE_SASUM_KERNEL

/*
 * Portable unit-stride kernel: returns the sum of ABS_K(x1[k]) for k in [0, n).
 *
 * The bulk of the vector is processed in groups of eight, with four partial
 * sums each receiving two elements per group (elements j and j + 4 go to
 * partial sum j). Whatever is left after the last full group is summed
 * separately and added at the end.
 */
static FLOAT sasum_kernel(BLASLONG n, FLOAT *x1)
{
    const BLASLONG full_groups_end = n & -8;   /* largest multiple of 8 not above n */
    FLOAT part0 = 0.0;
    FLOAT part1 = 0.0;
    FLOAT part2 = 0.0;
    FLOAT part3 = 0.0;
    FLOAT rest = 0.0;
    BLASLONG idx = 0;

    for (; idx < full_groups_end; idx += 8) {
        const FLOAT a0 = ABS_K(x1[idx + 0]);
        const FLOAT a1 = ABS_K(x1[idx + 1]);
        const FLOAT a2 = ABS_K(x1[idx + 2]);
        const FLOAT a3 = ABS_K(x1[idx + 3]);
        const FLOAT a4 = ABS_K(x1[idx + 4]);
        const FLOAT a5 = ABS_K(x1[idx + 5]);
        const FLOAT a6 = ABS_K(x1[idx + 6]);
        const FLOAT a7 = ABS_K(x1[idx + 7]);

        part0 += a0;
        part1 += a1;
        part2 += a2;
        part3 += a3;

        part0 += a4;
        part1 += a5;
        part2 += a6;
        part3 += a7;
    }

    /* Remaining elements after the last complete group of eight. */
    for (; idx < n; idx++) {
        rest += ABS_K(x1[idx]);
    }

    return part0 + part1 + part2 + part3 + rest;
}

#endif

/*
 * Entry point: sum of absolute values of n elements of x taken with stride inc_x.
 *
 * Returns 0.0 when n <= 0 or inc_x <= 0. Unit stride is delegated to
 * sasum_kernel; any other positive stride is handled by a scalar loop.
 */
FLOAT CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x)
{
    FLOAT acc = 0.0;
    BLASLONG offset = 0;
    BLASLONG stop;

    if (n <= 0 || inc_x <= 0) {
        return acc;
    }

    if (inc_x != 1) {
        /* One past the last index visited when stepping by inc_x. */
        stop = n * inc_x;
        while (offset < stop) {
            acc += ABS_K(x[offset]);
            offset += inc_x;
        }
    } else {
        acc = sasum_kernel(n, x);
    }

    return acc;
}
Loading