Merged
@@ -20,9 +20,11 @@ namespace XARCH {
// avx512/avx2/neon register length in bytes
static constexpr size_t vec_len_avx512 = 64lu;
static constexpr size_t vec_len_avx2 = 32lu;
static constexpr size_t vec_len_neon = 16lu;
// avx512/avx2/neon register length in floats
static constexpr size_t vec_len_f32_avx512 = vec_len_avx512 / sizeof(float);
static constexpr size_t vec_len_f32_avx2 = vec_len_avx2 / sizeof(float);
static constexpr size_t vec_len_f32_neon = vec_len_neon / sizeof(float);

#ifdef HAVE_AVX512F
inline __m512 cvt_bf16_to_fp32(const __m256i src) {
@@ -13,12 +13,17 @@
# include <immintrin.h>
#endif


#include "openvino/core/type/bfloat16.hpp"
#include "openvino/core/parallel.hpp"
#include "mha_single_token.hpp"
#include "common.hpp"
#include "softmax_kernel.hpp"

#if defined(OPENVINO_ARCH_ARM64)
# include <arm_neon.h>
#endif

namespace ov {
namespace Extensions {
namespace Cpu {
@@ -53,6 +58,13 @@ void cvt_copy(TA* dst, TB* src, size_t n) {
auto vb = mm256_uni_loadu_ps(src + i);
mm256_uni_storeu_ps(dst + i, vb);
}
#elif defined(OPENVINO_ARCH_ARM64)
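// NEON path: straight 4-lane fp32 copies (this branch assumes src and dst hold fp32)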
auto _dst = reinterpret_cast<float32_t*>(dst);
for (; i + vec_len_f32_neon <= n; i += vec_len_f32_neon) {
float32x4_t vb1 = vld1q_f32(src + i);
vst1q_f32(_dst + i, vb1);
}
#endif
for (; i < n; i++) {
dst[i] = src[i];
@@ -78,6 +90,15 @@ static void attn_acc_value(float* out, float weight, T* v, size_t S, float* scal
v_out = _mm256_fmadd_ps(attn_w_vec_fp32, v_value, v_out);
mm256_uni_storeu_ps(out + i, v_out);
}
#elif defined(OPENVINO_ARCH_ARM64)
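// NEON path: out[i] += weight * v[i] via vmlaq_f32 multiply-accumulate, four fp32 lanes per iteration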
float32x4_t attn_w_vec_fp32 = vdupq_n_f32(weight);
auto _v = reinterpret_cast<float32_t *>(v);
for (; i + vec_len_f32_neon <= S; i += vec_len_f32_neon) {
float32x4_t v_value = vld1q_f32(_v + i);
float32x4_t v_out = vld1q_f32(out + i);
v_out = vmlaq_f32(v_out, attn_w_vec_fp32, v_value);
vst1q_f32(out + i, v_out);
}
#endif
for (; i < S; i++) {
out[i] += weight * v[i];
@@ -308,6 +329,47 @@ static float sum_q_head(T* a, size_t n) {
vsum0 = _mm256_add_ps(vsum0, vsum2);
hsum(vsum0);
sum = _mm256_cvtss_f32(vsum0);
#elif defined(OPENVINO_ARCH_ARM64)
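// NEON path: four independent accumulators hide the vector-add latency;
// 2x- and 1x-vector remainders run before the scalar tail below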
float32x4_t vsum0 = vdupq_n_f32(0.0f);
float32x4_t vsum1 = vdupq_n_f32(0.0f);
float32x4_t vsum2 = vdupq_n_f32(0.0f);
float32x4_t vsum3 = vdupq_n_f32(0.0f);

for (; i + 4 * vec_len_f32_neon <= n; i += vec_len_f32_neon * 4) {
float32x4_t va0 = vld1q_f32(a + i);
float32x4_t va1 = vld1q_f32(a + i + vec_len_f32_neon);
float32x4_t va2 = vld1q_f32(a + i + vec_len_f32_neon * 2);
float32x4_t va3 = vld1q_f32(a + i + vec_len_f32_neon * 3);

vsum0 = vaddq_f32(va0, vsum0);
vsum1 = vaddq_f32(va1, vsum1);
vsum2 = vaddq_f32(va2, vsum2);
vsum3 = vaddq_f32(va3, vsum3);
}
if (i + 2 * vec_len_f32_neon <= n) {
float32x4_t va0 = vld1q_f32(a + i);
float32x4_t va1 = vld1q_f32(a + i + vec_len_f32_neon);

vsum0 = vaddq_f32(va0, vsum0);
vsum1 = vaddq_f32(va1, vsum1);
i += 2 * vec_len_f32_neon;
}
if (i + vec_len_f32_neon <= n) {
float32x4_t va0 = vld1q_f32(a + i);
vsum0 = vaddq_f32(va0, vsum0);
i += vec_len_f32_neon;
}

vsum0 = vaddq_f32(vsum0, vsum1);
vsum2 = vaddq_f32(vsum2, vsum3);
vsum0 = vaddq_f32(vsum0, vsum2);

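// reduce vsum0 horizontally: add high half to low half, then pairwise-add the remaining two lanes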
float32x2_t sum_low = vget_low_f32(vsum0);
float32x2_t sum_high = vget_high_f32(vsum0);
sum_low = vadd_f32(sum_low, sum_high);
sum_low = vpadd_f32(sum_low, sum_low);
sum = vget_lane_f32(sum_low, 0);
#endif

for (; i < n; i++) {
@@ -406,7 +468,59 @@ static float dot_product(TA* a, TB* b, size_t n, float* scale, float* zp, float*
vsum0 = _mm256_add_ps(vsum0, vsum2);
hsum(vsum0);
sum = _mm256_cvtss_f32(vsum0);

#elif defined(OPENVINO_ARCH_ARM64)
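// NEON path: 4x-unrolled dot product with four accumulators; the casts below assume fp32 inputs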
float32x4_t vsum0 = vdupq_n_f32(0.0f);
float32x4_t vsum1 = vdupq_n_f32(0.0f);
float32x4_t vsum2 = vdupq_n_f32(0.0f);
float32x4_t vsum3 = vdupq_n_f32(0.0f);

auto _a = reinterpret_cast<float32_t*>(a);
auto _b = reinterpret_cast<float32_t*>(b);

for (; i + 4 * vec_len_f32_neon <= n; i += vec_len_f32_neon * 4) {
float32x4_t va0 = vld1q_f32(_a + i);
float32x4_t va1 = vld1q_f32(_a + i + vec_len_f32_neon);
float32x4_t va2 = vld1q_f32(_a + i + vec_len_f32_neon * 2);
float32x4_t va3 = vld1q_f32(_a + i + vec_len_f32_neon * 3);

float32x4_t vb0 = vld1q_f32(_b + i);
float32x4_t vb1 = vld1q_f32(_b + i + vec_len_f32_neon);
float32x4_t vb2 = vld1q_f32(_b + i + vec_len_f32_neon * 2);
float32x4_t vb3 = vld1q_f32(_b + i + vec_len_f32_neon * 3);

vsum0 = vmlaq_f32(vsum0, va0, vb0);
vsum1 = vmlaq_f32(vsum1, va1, vb1);
vsum2 = vmlaq_f32(vsum2, va2, vb2);
vsum3 = vmlaq_f32(vsum3, va3, vb3);
}
if (i + 2 * vec_len_f32_neon <= n) {
float32x4_t va0 = vld1q_f32(_a + i);
float32x4_t va1 = vld1q_f32(_a + i + vec_len_f32_neon);

float32x4_t vb0 = vld1q_f32(_b + i);
float32x4_t vb1 = vld1q_f32(_b + i + vec_len_f32_neon);

vsum0 = vmlaq_f32(vsum0, va0, vb0);
vsum1 = vmlaq_f32(vsum1, va1, vb1);
i += 2 * vec_len_f32_neon;
}
if (i + vec_len_f32_neon <= n) {
float32x4_t va0 = vld1q_f32(_a + i);
float32x4_t vb0 = vld1q_f32(_b + i);
vsum0 = vmlaq_f32(vsum0, va0, vb0);
i += vec_len_f32_neon;
}

vsum0 = vaddq_f32(vsum0, vsum1);
vsum2 = vaddq_f32(vsum2, vsum3);
vsum0 = vaddq_f32(vsum0, vsum2);

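// same horizontal reduction as in sum_q_head: fold vsum0's four lanes down to one scalar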
float32x2_t temp_sum = vadd_f32(vget_low_f32(vsum0), vget_high_f32(vsum0));
temp_sum = vpadd_f32(temp_sum, temp_sum);
sum = vget_lane_f32(temp_sum, 0);
#endif

for (; i < n; i++) {
sum += a[i] * b[i];
}
@@ -593,6 +707,18 @@ static void attn_reduce(T* dst, float* temp, size_t M, size_t S, size_t temp_str
}
mm256_uni_storeu_ps(dst + i, result_vec_fp32);
}
#elif defined(OPENVINO_ARCH_ARM64)
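// NEON path: accumulate the M partial result rows (stride temp_stride) four fp32 lanes at a time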
auto _dst = reinterpret_cast<float32_t*>(dst);
for (; i + vec_len_f32_neon <= S; i += vec_len_f32_neon) {
auto* src = temp + i;
auto result_vec_fp32 = vdupq_n_f32(0.0f);
for (size_t m = 0; m < M; m++) {
auto o_vec_fp32 = vld1q_f32(src);
result_vec_fp32 = vaddq_f32(result_vec_fp32, o_vec_fp32);
src += temp_stride;
}
vst1q_f32(_dst + i, result_vec_fp32);
}
#endif
for (; i < S; i++) {
auto* src = temp + i;
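For reference, a minimal standalone sketch of the accumulate/fold/tail pattern the NEON hunks above share (sum_q_head, dot_product). The helper name hsum_f32_neon is illustrative only, not part of this PR, and it assumes an AArch64 toolchain providing <arm_neon.h>:

#include <arm_neon.h>
#include <cstddef>

// Illustrative sketch: sum n floats the same way the NEON reductions above do -
// vector accumulate, horizontal lane fold, then a scalar remainder loop.
static float hsum_f32_neon(const float* a, size_t n) {
    float32x4_t vsum = vdupq_n_f32(0.0f);
    size_t i = 0;
    for (; i + 4 <= n; i += 4)
        vsum = vaddq_f32(vsum, vld1q_f32(a + i));
    // fold 4 lanes -> 2 lanes -> 1 lane
    float32x2_t half = vadd_f32(vget_low_f32(vsum), vget_high_f32(vsum));
    half = vpadd_f32(half, half);
    float sum = vget_lane_f32(half, 0);
    for (; i < n; i++)
        sum += a[i];  // scalar tail
    return sum;
}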