11/* ******************************************************************************
2- * Copyright 2021-2022 Intel Corporation
3- * Copyright 2022 FUJITSU LIMITED
2+ * Copyright 2021-2024 Intel Corporation
3+ * Copyright 2022-2024 FUJITSU LIMITED
44*
55* Licensed under the Apache License, Version 2.0 (the "License");
66* you may not use this file except in compliance with the License.
@@ -47,9 +47,12 @@ jit_uni_shuffle_kernel_t<isa>::jit_uni_shuffle_kernel_t(
// Emits the instructions that build the tail predicate k_tail_mask_.
// Nothing is emitted in the visible part of this hunk unless
// conf_.simd_tail > 0.
// NOTE(review): the hunk ends before the function's closing brace, so any
// code after the `if` is not visible here.
4747template <cpu_isa_t isa>
4848void jit_uni_shuffle_kernel_t <isa>::prepare_mask() {
4949 using namespace data_type ;
50+ using namespace types ;
5051 if (conf_.simd_tail > 0 ) {
51- assert (utils::one_of (conf_.data_type , f32 , s32));
52- assert (conf_.simd_tail < isa_sveLen / sizeof (float ));
52+ /* Because "LD1H { <Zt>.S }, <Pg>/Z, [<Xn|SP>, <Zm>.S, UXTW #1]" is used
53+ to gather data for bf16 (32-bit lanes even for 16-bit elements),
54+ simd_tail must be evaluated with sizeof(uint32_t). */
55+ assert (conf_.simd_tail < isa_sveLen / sizeof (uint32_t ));
// vmm_tmp_.s is filled by index(..., 0, 1) -- presumably the SVE INDEX
// instruction producing lanes 0, 1, 2, ... (confirm against the emitter).
5356 index (vmm_tmp_.s , 0 , 1 );
// A lane of k_tail_mask_ is set where the lane value compares less than
// conf_.simd_tail, i.e. the first simd_tail 32-bit lanes if the index
// assumption above holds.
5457 cmplt (k_tail_mask_.s , P_ALL_ONE / T_z, vmm_tmp_.s , conf_.simd_tail );
5558 }
@@ -68,13 +71,17 @@ void jit_uni_shuffle_kernel_t<asimd>::prepare_mask() {}
// Emits a predicated gather load from reg_src_addr into vector register
// data_idx, using the per-lane offsets held in vector register indices_idx.
//
//   reg_src_addr  base address register of the gather
//   indices_idx   index of the vector register holding the offsets; it is
//                 modified in place by the lsr below
//   data_idx      index of the destination vector register
//   is_tail       selects k_tail_mask_ instead of k_full_mask_ as predicate
//
// The parameter name was mis-encoded as "(R)_src_addr" in this rendering
// ("&reg" decoded as an HTML entity); it is restored to "&reg_src_addr",
// which is the name the body uses.
//
// NOTE(review): the removed `else { assert(...) }` is not replaced, so for a
// data type other than f32/s32/bf16 no load is emitted and nothing
// diagnoses it.
6871template <cpu_isa_t isa>
6972void jit_uni_shuffle_kernel_t <isa>::gather_data(const XReg &reg_src_addr,
7073 const int indices_idx, const int data_idx, const bool is_tail) {
71- if (conf_.dt_size == sizeof (float )) {
72- const PReg &mask = is_tail ? k_tail_mask_ : k_full_mask_;
74+ using namespace data_type ;
75+ const PReg &mask = is_tail ? k_tail_mask_ : k_full_mask_;
76+
// 4-byte types: offsets are shifted right by 2 and the load scales them back
// with UXTW #2 -- presumably the incoming offsets are in bytes (TODO confirm).
77+ if (utils::one_of (conf_.data_type , f32 , s32)) {
7378 lsr (TRegS (indices_idx), TRegS (indices_idx), 2 );
7479 ld1w (TRegS (data_idx), mask / T_z,
7580 ptr (reg_src_addr, TRegS (indices_idx), UXTW, 2 ));
76- } else {
77- assert (!" unsupported emu_gather_data" );
// bf16: same scheme with a shift of 1 and a halfword load into 32-bit lanes.
81+ } else if (conf_.data_type == bf16 ) {
82+ lsr (TRegS (indices_idx), TRegS (indices_idx), 1 );
83+ ld1h (TRegS (data_idx), mask / T_z,
84+ ptr (reg_src_addr, TRegS (indices_idx), UXTW, 1 ));
7885 }
7986}
8087
@@ -97,21 +104,26 @@ void jit_uni_shuffle_kernel_t<asimd>::gather_data(const XReg &addr,
// Emits the store of vector register data_idx to reg_dst_addr + offset,
// followed by a call to append_zero_padding().
//
//   data_idx      index of the vector register to store
//   reg_dst_addr  base address register of the destination
//   offset        byte offset added to reg_dst_addr (via add_imm)
//   is_tail       store only the lanes enabled in k_tail_mask_, unless the
//                 padding extension below applies
//
// The parameter name was mis-encoded as "(R)_dst_addr" in this rendering
// ("&reg" decoded as an HTML entity); it is restored to "&reg_dst_addr",
// which is the name the body uses.
//
// NOTE(review): both `else // bf16` branches run for every data type that
// is not f32/s32; nothing here checks that the type is actually bf16.
97104template <cpu_isa_t isa>
98105void jit_uni_shuffle_kernel_t <isa>::store_data(const int data_idx,
99106 const XReg &reg_dst_addr, const int offset, const bool is_tail) {
107+ using namespace data_type ;
// True when this is a tail store and the padding plus the tail together
// cover at least one full vector (conf_.simd_w).
100108 const auto extend_for_padding
101109 = is_tail && padding_size_ + conf_.simd_tail >= conf_.simd_w ;
110+ const PReg &mask = is_tail ? k_tail_mask_ : P_ALL_ONE;
111+
// The destination address is now computed once, ahead of both branches.
112+ add_imm (X_DEFAULT_ADDR, reg_dst_addr, offset, X_TMP_0);
113+
102114 if (extend_for_padding) {
// Lanes under k_tail_mask_ come from the data register, the rest from
// vmm_zero_ (presumably all zeros), and the whole vector is stored.
103115 sel (vmm_tmp_.s , k_tail_mask_, TRegS (data_idx), vmm_zero_.s );
104- add_imm (X_DEFAULT_ADDR, reg_dst_addr, offset, X_TMP_0);
105- st1w (vmm_tmp_.s , P_ALL_ONE, ptr (X_DEFAULT_ADDR));
116+ if (utils::one_of (conf_.data_type , f32 , s32))
117+ st1w (vmm_tmp_.s , P_ALL_ONE, ptr (X_DEFAULT_ADDR));
118+ else // bf16
119+ st1h (vmm_tmp_.s , P_ALL_ONE, ptr (X_DEFAULT_ADDR));
106120 } else {
107- if (is_tail) {
108- add_imm (X_DEFAULT_ADDR, reg_dst_addr, offset, X_TMP_0);
109- st1w (TRegS (data_idx), k_tail_mask_, ptr (X_DEFAULT_ADDR));
110- } else {
111- add_imm (X_DEFAULT_ADDR, reg_dst_addr, offset, X_TMP_0);
112- st1w (TRegS (data_idx), P_ALL_ONE, ptr (X_DEFAULT_ADDR));
113- }
// Plain store under `mask` (k_tail_mask_ for a tail, P_ALL_ONE otherwise).
121+ if (utils::one_of (conf_.data_type , f32 , s32))
122+ st1w (TRegS (data_idx), mask, ptr (X_DEFAULT_ADDR));
123+ else // bf16
124+ st1h (TRegS (data_idx), mask, ptr (X_DEFAULT_ADDR));
114125 }
126+
// The extension flag is forwarded only when isa_sveLen > 128; otherwise
// false is passed.
115127 append_zero_padding (
116128 reg_dst_, isa_sveLen > 128 ? extend_for_padding : false );
117129}
0 commit comments