@@ -106,7 +106,7 @@ bool ReduceKey::operator==(const ReduceKey &rhs) const {
106106
107107// some utility functions
// Returns true when `type` is one of the floating-point data types listed in the
// return expression. In this diff the '-' line is the previous body (f32 or bf16
// only) and the '+' line is its replacement, which additionally accepts f16.
108108static inline bool isFloatCompatible (memory::data_type type) {
109- return memory::data_type::f32 == type || memory::data_type::bf16 == type;
109+ return memory::data_type::f32 == type || memory::data_type::bf16 == type || memory::data_type:: f16 == type ;
110110}
111111
112112template <cpu_isa_t isa>
@@ -207,6 +207,9 @@ struct jit_uni_reduce_kernel_f32 : public jit_uni_reduce_kernel, public jit_gene
207207 Xmm xmm_aux3 = Xmm(7 );
208208 Vmm vmm_idx = Vmm(8 );
209209 Vmm vmm_mask = Vmm(9 );
210+ Vmm vmm_dst_fp16 = Vmm(10 );
211+ Ymm ymm_dst_fp16 = Ymm(10 );
212+ Xmm xmm_dst_fp16 = Xmm(10 );
210213
211214 const Xbyak::Opmask k_mask = Xbyak::Opmask(1 );
212215
@@ -570,6 +573,7 @@ struct jit_uni_reduce_kernel_f32 : public jit_uni_reduce_kernel, public jit_gene
570573 }
571574 break ;
572575 case memory::data_type::bf16 :
576+ case memory::data_type::f16 :
573577 case memory::data_type::s8:
574578 case memory::data_type::u8 :
575579 pack_gathered_vector (vmm_src, vmm_idx, offset, jcp_.src_dt );
@@ -597,6 +601,10 @@ struct jit_uni_reduce_kernel_f32 : public jit_uni_reduce_kernel, public jit_gene
597601 mov (reg_tmp_64.cvt16 (), table_idx);
598602 mov (ptr[rsp + i * sizeof (ov::intel_cpu::bfloat16_t )], reg_tmp_64.cvt16 ());
599603 break ;
604+ case memory::data_type::f16 :
605+ mov (reg_tmp_64.cvt16 (), table_idx);
606+ mov (ptr[rsp + i * sizeof (ov::float16)], reg_tmp_64.cvt16 ());
607+ break ;
600608 case memory::data_type::s8:
601609 case memory::data_type::u8 :
602610 mov (reg_tmp_64.cvt8 (), table_idx);
@@ -615,7 +623,10 @@ struct jit_uni_reduce_kernel_f32 : public jit_uni_reduce_kernel, public jit_gene
615623 case memory::data_type::bf16 :
616624 uni_vpmovzxwd (vmm_val, ptr[rsp]);
617625 uni_vpslld (vmm_val, vmm_val, 16 );
618- break ;
626+ break ;
627+ case memory::data_type::f16 :
628+ vcvtph2ps (vmm_val, ptr[rsp]);
629+ break ;
619630 case memory::data_type::s8:
620631 uni_vpmovsxbd (vmm_val, ptr[rsp]);
621632 break ;
@@ -870,6 +881,9 @@ struct jit_uni_reduce_kernel_f32 : public jit_uni_reduce_kernel, public jit_gene
870881 uni_vpmovzxwd (vmm_src, op);
871882 uni_vpslld (vmm_src, vmm_src, 16 );
872883 break ;
884+ case memory::data_type::f16 :
885+ vcvtph2ps (vmm_src, op);
886+ break ;
873887 case memory::data_type::s8:
874888 uni_vpmovsxbd (vmm_src, op);
875889 break ;
@@ -894,6 +908,9 @@ struct jit_uni_reduce_kernel_f32 : public jit_uni_reduce_kernel, public jit_gene
894908 uni_vpinsrw (xmm_src, xmm_src, op, 0x0 );
895909 uni_vpslld (xmm_src, xmm_src, 16 );
896910 break ;
911+ case memory::data_type::f16 :
912+ vcvtph2ps (xmm_src, op);
913+ break ;
897914 case memory::data_type::s8:
898915 movsx (reg_tmp_32, op);
899916 uni_vmovq (xmm_src, reg_tmp_64);
@@ -928,6 +945,10 @@ struct jit_uni_reduce_kernel_f32 : public jit_uni_reduce_kernel, public jit_gene
928945 uni_vcvtneps2bf16->emit_code ({static_cast <size_t >(vmm_dst.getIdx ())}, {static_cast <size_t >(ymm_dst.getIdx ())});
929946 vmovdqu16 (op, ymm_dst);
930947 break ;
948+ case memory::data_type::f16 :
949+ vcvtps2ph (ymm_dst_fp16, vmm_dst, 0x4 );
950+ vmovdqu16 (op, ymm_dst_fp16);
951+ break ;
931952 case memory::data_type::s8:
932953 if (isa == cpu::x64::avx512_core) {
933954 vpmovsdb (op, vmm_dst);
@@ -976,6 +997,10 @@ struct jit_uni_reduce_kernel_f32 : public jit_uni_reduce_kernel, public jit_gene
976997 uni_vpsrld (xmm_dst, xmm_dst, 16 );
977998 uni_vpextrw (op, xmm_dst, 0x0 );
978999 break ;
1000+ case memory::data_type::f16 :
1001+ vcvtps2ph (xmm_dst_fp16, xmm_dst, 0x4 );
1002+ vmovdqu16 (op, xmm_dst_fp16);
1003+ break ;
9791004 case memory::data_type::s8:
9801005 uni_vpackssdw (xmm_dst, xmm_dst, xmm_dst);
9811006 uni_vpacksswb (xmm_dst, xmm_dst, xmm_dst);
@@ -1214,6 +1239,10 @@ struct jit_uni_reduce_post_kernel_f32 : public jit_uni_reduce_post_kernel, publi
12141239 Vmm vmm_d_weights = Vmm(7 );
12151240 Vmm vmm_d_bias = Vmm(8 );
12161241
1242+ Vmm vmm_dst_fp16 = Vmm(9 );
1243+ Ymm ymm_dst_fp16 = Ymm(9 );
1244+ Xmm xmm_dst_fp16 = Xmm(9 );
1245+
12171246 std::shared_ptr<jit_uni_vcvtneps2bf16> uni_vcvtneps2bf16;
12181247 std::shared_ptr<jit_uni_eltwise_injector_f32<isa>> log_injector;
12191248
@@ -1486,6 +1515,9 @@ struct jit_uni_reduce_post_kernel_f32 : public jit_uni_reduce_post_kernel, publi
14861515 uni_vpmovzxwd (vmm_src, op);
14871516 uni_vpslld (vmm_src, vmm_src, 16 );
14881517 break ;
1518+ case memory::data_type::f16 :
1519+ vcvtph2ps (vmm_src, op);
1520+ break ;
14891521 case memory::data_type::s8:
14901522 uni_vpmovsxbd (vmm_src, op);
14911523 break ;
@@ -1510,6 +1542,9 @@ struct jit_uni_reduce_post_kernel_f32 : public jit_uni_reduce_post_kernel, publi
15101542 uni_vpinsrw (xmm_src, xmm_src, op, 0x0 );
15111543 uni_vpslld (xmm_src, xmm_src, 16 );
15121544 break ;
1545+ case memory::data_type::f16 :
1546+ vcvtph2ps (xmm_src, op);
1547+ break ;
15131548 case memory::data_type::s8:
15141549 movsx (reg_tmp_32, op);
15151550 uni_vmovq (xmm_src, reg_tmp_64);
@@ -1544,6 +1579,10 @@ struct jit_uni_reduce_post_kernel_f32 : public jit_uni_reduce_post_kernel, publi
15441579 uni_vcvtneps2bf16->emit_code ({static_cast <size_t >(vmm_dst.getIdx ())}, {static_cast <size_t >(ymm_dst.getIdx ())});
15451580 vmovdqu16 (op, ymm_dst);
15461581 break ;
1582+ case memory::data_type::f16 :
1583+ vcvtps2ph (ymm_dst_fp16, vmm_dst, 0x4 );
1584+ vmovdqu16 (op, ymm_dst_fp16);
1585+ break ;
15471586 case memory::data_type::s8:
15481587 if (isa == cpu::x64::avx512_core) {
15491588 vpmovsdb (op, vmm_dst);
@@ -1592,6 +1631,10 @@ struct jit_uni_reduce_post_kernel_f32 : public jit_uni_reduce_post_kernel, publi
15921631 uni_vpsrld (xmm_dst, xmm_dst, 16 );
15931632 uni_vpextrw (op, xmm_dst, 0x0 );
15941633 break ;
1634+ case memory::data_type::f16 :
1635+ vcvtps2ph (xmm_dst_fp16, xmm_dst, 0x4 );
1636+ vmovdqu16 (op, xmm_dst_fp16);
1637+ break ;
15951638 case memory::data_type::s8:
15961639 uni_vpackssdw (xmm_dst, xmm_dst, xmm_dst);
15971640 uni_vpacksswb (xmm_dst, xmm_dst, xmm_dst);
@@ -1806,9 +1849,9 @@ void Reduce::initSupportedPrimitiveDescriptors() {
18061849 jit_mode = canApplyJIT (input_prec, output_prec);
18071850
18081851 if (jit_mode) {
1809- // Since in jit mode we use the output memory as an intermediate accumulator for certain reduce modes, we can't use BF16 output precision due to
1852+ // Since in jit mode we use the output memory as an intermediate accumulator for certain reduce modes, we can't use BF16/FP16 output precision due to
18101853 // the possible accuracy loss. Therefore, for such modes, we will change the output precision to FP32.
1811- if (Precision::BF16 == output_prec) {
1854+ if (Precision::BF16 == output_prec || Precision::FP16 == output_prec ) {
18121855 if (!mayiuse (avx512_core)) {
18131856 output_prec = Precision::FP32;
18141857 } else if (algorithm != Algorithm::ReduceAnd && algorithm != Algorithm::ReduceOr &&
@@ -2734,6 +2777,9 @@ inline void Reduce::init_dst_data(uint8_t *out_ptr, size_t dst_size) {
27342777 } else if (output_prec == Precision::BF16) {
27352778 auto out_p = reinterpret_cast <bfloat16_t *>(out_ptr);
27362779 parallel_for (dst_size / dst_data_size, [&](size_t i) { out_p[i] = static_cast <bfloat16_t >(1 ); });
2780+ } else if (output_prec == Precision::FP16) {
2781+ auto out_p = reinterpret_cast <ov::float16*>(out_ptr);
2782+ parallel_for (dst_size / dst_data_size, [&](size_t i) { out_p[i] = static_cast <ov::float16>(1 ); });
27372783 } else if (output_prec == Precision::U8) {
27382784 auto out_p = reinterpret_cast <uint8_t *>(out_ptr);
27392785 parallel_for (dst_size / dst_data_size, [&](size_t i) { out_p[i] = static_cast <uint8_t >(1 ); });
@@ -2752,6 +2798,9 @@ inline void Reduce::init_dst_data(uint8_t *out_ptr, size_t dst_size) {
27522798 } else if (output_prec == Precision::BF16) {
27532799 auto out_p = reinterpret_cast <bfloat16_t *>(out_ptr);
27542800 parallel_for (dst_size / dst_data_size, [&](size_t i) { out_p[i] = std::numeric_limits<bfloat16_t >::lowest (); });
2801+ } else if (output_prec == Precision::FP16) {
2802+ auto out_p = reinterpret_cast <ov::float16*>(out_ptr);
2803+ parallel_for (dst_size / dst_data_size, [&](size_t i) { out_p[i] = std::numeric_limits<ov::float16>::lowest (); });
27552804 } else if (output_prec == Precision::U8) {
27562805 auto out_p = reinterpret_cast <uint8_t *>(out_ptr);
27572806 parallel_for (dst_size / dst_data_size, [&](size_t i) { out_p[i] = std::numeric_limits<uint8_t >::min (); });
@@ -2770,6 +2819,9 @@ inline void Reduce::init_dst_data(uint8_t *out_ptr, size_t dst_size) {
27702819 } else if (output_prec == Precision::BF16) {
27712820 auto out_p = reinterpret_cast <bfloat16_t *>(out_ptr);
27722821 parallel_for (dst_size / dst_data_size, [&](size_t i) { out_p[i] = std::numeric_limits<bfloat16_t >::max (); });
2822+ } else if (output_prec == Precision::FP16) {
2823+ auto out_p = reinterpret_cast <ov::float16*>(out_ptr);
2824+ parallel_for (dst_size / dst_data_size, [&](size_t i) { out_p[i] = std::numeric_limits<ov::float16>::max (); });
27732825 } else if (output_prec == Precision::U8) {
27742826 auto out_p = reinterpret_cast <uint8_t *>(out_ptr);
27752827 parallel_for (dst_size / dst_data_size, [&](size_t i) { out_p[i] = std::numeric_limits<uint8_t >::max (); });
@@ -3133,6 +3185,7 @@ std::vector<int> Reduce::update_src_dims() {
31333185bool Reduce::canApplyJIT (const Precision &input_prec, const Precision &output_prec) const {
31343186 static const Precision supportedPrecisions[] = {
31353187 Precision::FP32,
3188+ Precision::FP16,
31363189 Precision::BF16,
31373190 Precision::I32,
31383191 Precision::I8,
0 commit comments