@@ -64,7 +64,35 @@ inline void RefreshAccumulator(Accumulator accumulator, Board* board, const int
6464 }
6565}
6666
67- #if defined(__AVX2__ )
67+ #if defined(__AVX512F__ )
68+ const size_t WIDTH = sizeof (__m512i ) / sizeof (int16_t );
69+ const size_t CHUNKS = N_HIDDEN / WIDTH ;
70+
71+ int OutputLayer (Accumulator stm , Accumulator xstm ) {
72+ int result = OUTPUT_BIAS * QUANTIZATION_PRECISION_IN ;
73+
74+ const __m512i zero = _mm512_setzero_si512 ();
75+ __m512i s0 = _mm512_setzero_si512 ();
76+ __m512i s1 = _mm512_setzero_si512 ();
77+
78+ for (size_t j = 0 ; j < CHUNKS ; j ++ ) {
79+ const __m512i ac0 = _mm512_max_epi16 (* (__m512i * )& stm [j * WIDTH ], zero );
80+ const __m512i ac1 = _mm512_max_epi16 (* (__m512i * )& xstm [j * WIDTH ], zero );
81+
82+ s0 = _mm512_add_epi32 (s0 , _mm512_madd_epi16 (ac0 , * (__m512i * )& HIDDEN_WEIGHTS [j * WIDTH ]));
83+ s1 = _mm512_add_epi32 (s1 , _mm512_madd_epi16 (ac1 , * (__m512i * )& HIDDEN_WEIGHTS [j * WIDTH + N_HIDDEN ]));
84+ }
85+
86+ const __m512i r16 = _mm512_add_epi32 (s0 , s1 );
87+ const __m256i r8 = _mm256_add_epi32 (_mm512_castsi512_si256 (r16 ), _mm512_extracti32x8_epi32 (r16 , 1 ));
88+ const __m128i r4 = _mm_add_epi32 (_mm256_castsi256_si128 (r8 ), _mm256_extractf128_si256 (r8 , 1 ));
89+ const __m128i r2 = _mm_add_epi32 (r4 , _mm_srli_si128 (r4 , 8 ));
90+ const __m128i r1 = _mm_add_epi32 (r2 , _mm_srli_si128 (r2 , 4 ));
91+
92+ result += _mm_cvtsi128_si32 (r1 );
93+ return result / QUANTIZATION_PRECISION_IN / QUANTIZATION_PRECISION_OUT ;
94+ }
95+ #elif defined(__AVX2__ )
6896const size_t WIDTH = sizeof (__m256i ) / sizeof (int16_t );
6997const size_t CHUNKS = N_HIDDEN / WIDTH ;
7098
0 commit comments