
Commit 4c6b4c5

[CPU] Reduce node supports fp16 precision
1 parent 22aa219 commit 4c6b4c5

4 files changed (+140, -32 lines)


src/plugins/intel_cpu/src/graph.cpp

Lines changed: 1 addition & 0 deletions
@@ -312,6 +312,7 @@ void Graph::Replicate(const CNNNetwork &network) {
         for (size_t i = 0; i < childEdges.size(); i++) {
             const auto child = childEdges[i]->getChild();
             if (child->getOriginalInputPrecisionAtPort(childEdges[i]->getOutputNum()) != Precision::BF16 &&
+                child->getOriginalInputPrecisionAtPort(childEdges[i]->getOutputNum()) != Precision::FP16 &&
                 // remove this WA when #78939 is resolved
                 !hasSubgraphConsumers(child))
                 child->setOriginalInputPrecisionAtPort(childEdges[i]->getOutputNum(), precToSet);
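
The guard now treats FP16 the same way as BF16: when the enforced inference precision is pushed down to consumer inputs, a port that already expects a reduced-precision float keeps its original precision. A condensed sketch of the condition, with a hypothetical helper name (the real code inlines it in Graph::Replicate):

static bool keepsOriginalPrecision(const NodePtr& child, int port) {
    const auto prec = child->getOriginalInputPrecisionAtPort(port);
    // A port that is already BF16 or FP16 was chosen deliberately (e.g. via
    // the precision hint), so the enforced precision must not overwrite it.
    return prec == Precision::BF16 || prec == Precision::FP16;
}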

src/plugins/intel_cpu/src/nodes/reduce.cpp

Lines changed: 57 additions & 4 deletions
@@ -106,7 +106,7 @@ bool ReduceKey::operator==(const ReduceKey &rhs) const {

 // some utility functions
 static inline bool isFloatCompatible(memory::data_type type) {
-    return memory::data_type::f32 == type || memory::data_type::bf16 == type;
+    return memory::data_type::f32 == type || memory::data_type::bf16 == type || memory::data_type::f16 == type;
 }

 template <cpu_isa_t isa>
@@ -207,6 +207,9 @@ struct jit_uni_reduce_kernel_f32 : public jit_uni_reduce_kernel, public jit_gene
     Xmm xmm_aux3 = Xmm(7);
     Vmm vmm_idx = Vmm(8);
     Vmm vmm_mask = Vmm(9);
+    Vmm vmm_dst_fp16 = Vmm(10);
+    Ymm ymm_dst_fp16 = Ymm(10);
+    Xmm xmm_dst_fp16 = Xmm(10);

     const Xbyak::Opmask k_mask = Xbyak::Opmask(1);

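Note that the three new fields are views of one physical register: in Xbyak, register objects built from the same index alias the same register at different widths, so the kernel can use whichever width matches the ISA. For illustration:

Xbyak::Xmm x10(10);  // 128-bit view of register 10
Xbyak::Ymm y10(10);  // 256-bit view; its low half is x10
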
@@ -570,6 +573,7 @@ struct jit_uni_reduce_kernel_f32 : public jit_uni_reduce_kernel, public jit_gene
                 }
                 break;
             case memory::data_type::bf16:
+            case memory::data_type::f16:
             case memory::data_type::s8:
             case memory::data_type::u8:
                 pack_gathered_vector(vmm_src, vmm_idx, offset, jcp_.src_dt);
@@ -597,6 +601,10 @@ struct jit_uni_reduce_kernel_f32 : public jit_uni_reduce_kernel, public jit_gene
                 mov(reg_tmp_64.cvt16(), table_idx);
                 mov(ptr[rsp + i * sizeof(ov::intel_cpu::bfloat16_t)], reg_tmp_64.cvt16());
                 break;
+            case memory::data_type::f16:
+                mov(reg_tmp_64.cvt16(), table_idx);
+                mov(ptr[rsp + i * sizeof(ov::float16)], reg_tmp_64.cvt16());
+                break;
             case memory::data_type::s8:
             case memory::data_type::u8:
                 mov(reg_tmp_64.cvt8(), table_idx);
@@ -615,7 +623,10 @@ struct jit_uni_reduce_kernel_f32 : public jit_uni_reduce_kernel, public jit_gene
             case memory::data_type::bf16:
                 uni_vpmovzxwd(vmm_val, ptr[rsp]);
                 uni_vpslld(vmm_val, vmm_val, 16);
-                break;
+                break;
+            case memory::data_type::f16:
+                vcvtph2ps(vmm_val, ptr[rsp]);
+                break;
             case memory::data_type::s8:
                 uni_vpmovsxbd(vmm_val, ptr[rsp]);
                 break;
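
The new f16 branches rely on vcvtph2ps to widen half-precision values to fp32 before any arithmetic, so the reduce math itself always runs in fp32. A rough equivalent in F16C intrinsics (a sketch, not the JIT code itself; requires -mf16c):

#include <immintrin.h>
#include <cstdint>

// Widen 8 packed fp16 values to fp32, as the vcvtph2ps load path does.
// fp16 -> fp32 is exact: every half-precision value is representable in float.
static inline __m256 load_fp16_as_fp32(const uint16_t* src) {
    const __m128i half = _mm_loadu_si128(reinterpret_cast<const __m128i*>(src));
    return _mm256_cvtph_ps(half);
}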
@@ -870,6 +881,9 @@ struct jit_uni_reduce_kernel_f32 : public jit_uni_reduce_kernel, public jit_gene
                 uni_vpmovzxwd(vmm_src, op);
                 uni_vpslld(vmm_src, vmm_src, 16);
                 break;
+            case memory::data_type::f16:
+                vcvtph2ps(vmm_src, op);
+                break;
             case memory::data_type::s8:
                 uni_vpmovsxbd(vmm_src, op);
                 break;
@@ -894,6 +908,9 @@ struct jit_uni_reduce_kernel_f32 : public jit_uni_reduce_kernel, public jit_gene
                 uni_vpinsrw(xmm_src, xmm_src, op, 0x0);
                 uni_vpslld(xmm_src, xmm_src, 16);
                 break;
+            case memory::data_type::f16:
+                vcvtph2ps(xmm_src, op);
+                break;
             case memory::data_type::s8:
                 movsx(reg_tmp_32, op);
                 uni_vmovq(xmm_src, reg_tmp_64);
@@ -928,6 +945,10 @@ struct jit_uni_reduce_kernel_f32 : public jit_uni_reduce_kernel, public jit_gene
                 uni_vcvtneps2bf16->emit_code({static_cast<size_t>(vmm_dst.getIdx())}, {static_cast<size_t>(ymm_dst.getIdx())});
                 vmovdqu16(op, ymm_dst);
                 break;
+            case memory::data_type::f16:
+                vcvtps2ph(ymm_dst_fp16, vmm_dst, 0x4);
+                vmovdqu16(op, ymm_dst_fp16);
+                break;
             case memory::data_type::s8:
                 if (isa == cpu::x64::avx512_core) {
                     vpmovsdb(op, vmm_dst);
@@ -976,6 +997,10 @@ struct jit_uni_reduce_kernel_f32 : public jit_uni_reduce_kernel, public jit_gene
                 uni_vpsrld(xmm_dst, xmm_dst, 16);
                 uni_vpextrw(op, xmm_dst, 0x0);
                 break;
+            case memory::data_type::f16:
+                vcvtps2ph(xmm_dst_fp16, xmm_dst, 0x4);
+                vmovdqu16(op, xmm_dst_fp16);
+                break;
             case memory::data_type::s8:
                 uni_vpackssdw(xmm_dst, xmm_dst, xmm_dst);
                 uni_vpacksswb(xmm_dst, xmm_dst, xmm_dst);
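
The store path mirrors the load: vcvtps2ph narrows the fp32 accumulator to fp16 and vmovdqu16 writes the packed halves. The immediate 0x4 sets the "use MXCSR" bit, so the narrowing follows the current rounding mode, normally round-to-nearest-even. An intrinsics sketch of the same operation (-mf16c assumed; not the JIT code itself):

#include <immintrin.h>
#include <cstdint>

// Narrow 8 fp32 values to fp16 and store them, as the vcvtps2ph path does.
// imm8 = 0x4 -> round according to MXCSR (round-to-nearest-even by default).
static inline void store_fp32_as_fp16(uint16_t* dst, __m256 v) {
    const __m128i half = _mm256_cvtps_ph(v, 0x4);
    _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), half);
}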
@@ -1214,6 +1239,10 @@ struct jit_uni_reduce_post_kernel_f32 : public jit_uni_reduce_post_kernel, publi
     Vmm vmm_d_weights = Vmm(7);
     Vmm vmm_d_bias = Vmm(8);

+    Vmm vmm_dst_fp16 = Vmm(9);
+    Ymm ymm_dst_fp16 = Ymm(9);
+    Xmm xmm_dst_fp16 = Xmm(9);
+
     std::shared_ptr<jit_uni_vcvtneps2bf16> uni_vcvtneps2bf16;
     std::shared_ptr<jit_uni_eltwise_injector_f32<isa>> log_injector;

@@ -1486,6 +1515,9 @@ struct jit_uni_reduce_post_kernel_f32 : public jit_uni_reduce_post_kernel, publi
                 uni_vpmovzxwd(vmm_src, op);
                 uni_vpslld(vmm_src, vmm_src, 16);
                 break;
+            case memory::data_type::f16:
+                vcvtph2ps(vmm_src, op);
+                break;
             case memory::data_type::s8:
                 uni_vpmovsxbd(vmm_src, op);
                 break;
@@ -1510,6 +1542,9 @@ struct jit_uni_reduce_post_kernel_f32 : public jit_uni_reduce_post_kernel, publi
                 uni_vpinsrw(xmm_src, xmm_src, op, 0x0);
                 uni_vpslld(xmm_src, xmm_src, 16);
                 break;
+            case memory::data_type::f16:
+                vcvtph2ps(xmm_src, op);
+                break;
             case memory::data_type::s8:
                 movsx(reg_tmp_32, op);
                 uni_vmovq(xmm_src, reg_tmp_64);
@@ -1544,6 +1579,10 @@ struct jit_uni_reduce_post_kernel_f32 : public jit_uni_reduce_post_kernel, publi
                 uni_vcvtneps2bf16->emit_code({static_cast<size_t>(vmm_dst.getIdx())}, {static_cast<size_t>(ymm_dst.getIdx())});
                 vmovdqu16(op, ymm_dst);
                 break;
+            case memory::data_type::f16:
+                vcvtps2ph(ymm_dst_fp16, vmm_dst, 0x4);
+                vmovdqu16(op, ymm_dst_fp16);
+                break;
             case memory::data_type::s8:
                 if (isa == cpu::x64::avx512_core) {
                     vpmovsdb(op, vmm_dst);
@@ -1592,6 +1631,10 @@ struct jit_uni_reduce_post_kernel_f32 : public jit_uni_reduce_post_kernel, publi
                 uni_vpsrld(xmm_dst, xmm_dst, 16);
                 uni_vpextrw(op, xmm_dst, 0x0);
                 break;
+            case memory::data_type::f16:
+                vcvtps2ph(xmm_dst_fp16, xmm_dst, 0x4);
+                vmovdqu16(op, xmm_dst_fp16);
+                break;
             case memory::data_type::s8:
                 uni_vpackssdw(xmm_dst, xmm_dst, xmm_dst);
                 uni_vpacksswb(xmm_dst, xmm_dst, xmm_dst);
@@ -1806,9 +1849,9 @@ void Reduce::initSupportedPrimitiveDescriptors() {
     jit_mode = canApplyJIT(input_prec, output_prec);

     if (jit_mode) {
-        // Since in jit mode we use the output memory as an intermediate accumulator for certain reduce modes, we can't use BF16 output precision due to
+        // Since in jit mode we use the output memory as an intermediate accumulator for certain reduce modes, we can't use BF16/FP16 output precision due to
         // the possible accuracy loss. Therefore, for such mods, we will change the output precision to FP32.
-        if (Precision::BF16 == output_prec) {
+        if (Precision::BF16 == output_prec || Precision::FP16 == output_prec) {
             if (!mayiuse(avx512_core)) {
                 output_prec = Precision::FP32;
             } else if (algorithm != Algorithm::ReduceAnd && algorithm != Algorithm::ReduceOr &&
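
The fallback exists because fp16 has only a 10-bit mantissa: once a running sum passes 2048, adding 1.0 is rounded away entirely. A minimal standalone demo of that loss (scalar F16C conversions, -mf16c; not part of the plugin):

#include <immintrin.h>
#include <cstdio>

int main() {
    unsigned short acc = _cvtss_sh(0.0f, 0x4);  // fp16-encoded accumulator
    for (int i = 0; i < 4096; ++i) {
        // Round back to fp16 after every add, as an fp16 output buffer would.
        acc = _cvtss_sh(_cvtsh_ss(acc) + 1.0f, 0x4);
    }
    std::printf("%.1f\n", _cvtsh_ss(acc));  // prints 2048.0, not 4096.0
    return 0;
}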
@@ -2734,6 +2777,9 @@ inline void Reduce::init_dst_data(uint8_t *out_ptr, size_t dst_size) {
     } else if (output_prec == Precision::BF16) {
         auto out_p = reinterpret_cast<bfloat16_t*>(out_ptr);
         parallel_for(dst_size / dst_data_size, [&](size_t i) { out_p[i] = static_cast<bfloat16_t>(1); });
+    } else if (output_prec == Precision::FP16) {
+        auto out_p = reinterpret_cast<ov::float16*>(out_ptr);
+        parallel_for(dst_size / dst_data_size, [&](size_t i) { out_p[i] = static_cast<ov::float16>(1); });
     } else if (output_prec == Precision::U8) {
         auto out_p = reinterpret_cast<uint8_t *>(out_ptr);
         parallel_for(dst_size / dst_data_size, [&](size_t i) { out_p[i] = static_cast<uint8_t>(1); });
@@ -2752,6 +2798,9 @@ inline void Reduce::init_dst_data(uint8_t *out_ptr, size_t dst_size) {
     } else if (output_prec == Precision::BF16) {
         auto out_p = reinterpret_cast<bfloat16_t*>(out_ptr);
         parallel_for(dst_size / dst_data_size, [&](size_t i) { out_p[i] = std::numeric_limits<bfloat16_t>::lowest(); });
+    } else if (output_prec == Precision::FP16) {
+        auto out_p = reinterpret_cast<ov::float16*>(out_ptr);
+        parallel_for(dst_size / dst_data_size, [&](size_t i) { out_p[i] = std::numeric_limits<ov::float16>::lowest(); });
     } else if (output_prec == Precision::U8) {
         auto out_p = reinterpret_cast<uint8_t *>(out_ptr);
         parallel_for(dst_size / dst_data_size, [&](size_t i) { out_p[i] = std::numeric_limits<uint8_t>::min(); });
@@ -2770,6 +2819,9 @@ inline void Reduce::init_dst_data(uint8_t *out_ptr, size_t dst_size) {
     } else if (output_prec == Precision::BF16) {
         auto out_p = reinterpret_cast<bfloat16_t*>(out_ptr);
         parallel_for(dst_size / dst_data_size, [&](size_t i) { out_p[i] = std::numeric_limits<bfloat16_t>::max(); });
+    } else if (output_prec == Precision::FP16) {
+        auto out_p = reinterpret_cast<ov::float16*>(out_ptr);
+        parallel_for(dst_size / dst_data_size, [&](size_t i) { out_p[i] = std::numeric_limits<ov::float16>::max(); });
     } else if (output_prec == Precision::U8) {
         auto out_p = reinterpret_cast<uint8_t *>(out_ptr);
         parallel_for(dst_size / dst_data_size, [&](size_t i) { out_p[i] = std::numeric_limits<uint8_t>::max(); });
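
All three hunks follow the same pattern: before the kernel runs, the destination buffer is filled with the identity element of the reduction (1 for product-like modes, lowest() for max, max() for min), and FP16 simply joins the list of element types. A condensed sketch of the idea (helper name hypothetical):

// Fill the output with the neutral value so partial results can be
// accumulated into it directly.
template <typename T>
static void fill_identity(T* out, size_t n, T identity) {
    for (size_t i = 0; i < n; ++i)
        out[i] = identity;  // T(1) for Prod, lowest() for Max, max() for Min
}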
@@ -3133,6 +3185,7 @@ std::vector<int> Reduce::update_src_dims() {
 bool Reduce::canApplyJIT(const Precision &input_prec, const Precision &output_prec) const {
     static const Precision supportedPrecisions[] = {
         Precision::FP32,
+        Precision::FP16,
         Precision::BF16,
         Precision::I32,
         Precision::I8,

src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp

Lines changed: 3 additions & 2 deletions
@@ -74,7 +74,7 @@ std::vector<std::string> disabledTestPatterns() {
         R"(.*OVCompiledModelBaseTest.*(CanGetInputsInfoAndCheck|canSetConfigToCompiledModel).*)",
         R"(.*Behavior.*CorrectConfigCheck.*(canSetConfigAndCheckGetConfig|canSetConfigTwiceAndCheckGetConfig).*CPU_BIND_THREAD=YES.*)",
         // Issue: 72021 Unreasonable abs_threshold for comparing bf16 results
-        R"(.*smoke_Reduce.*type=(Prod|Min).*netPRC=(BF|bf)16.*)",
+        R"(.*smoke_Reduce.*type=(Prod|Min).*INFERENCE_PRECISION_HINT=(BF|bf)16.*)",
         // TODO: 56520 Accuracy mismatch
         R"(.*ReduceOpsLayerTest.*type=Mean_.*netPRC=(I64|I32).*)",
         R"(.*ReduceOpsLayerTest.*type=Mean_.*netPRC=U64.*)",
@@ -237,10 +237,11 @@ std::vector<std::string> disabledTestPatterns() {
 #endif

     if (!InferenceEngine::with_cpu_x86_avx512_core()) {
-        // on platforms which do not support bfloat16, we are disabling bf16 tests since there are no bf16 primitives,
+        // on platforms which do not support bfloat16, we are disabling bf16/f16 tests since there are no bf16/f16 primitives,
         // tests are useless on such platforms
         retVector.emplace_back(R"(.*(BF|bf)16.*)");
         retVector.emplace_back(R"(.*bfloat16.*)");
+        retVector.emplace_back(R"(.*INFERENCE_PRECISION_HINT=(F|f)16.*)");
         // MatMul in Snippets uses BRGEMM that is supported only on AVX512 platforms
         // Disabled Snippets MHA tests as well because MHA pattern contains MatMul
         retVector.emplace_back(R"(.*Snippets.*MHA.*)");
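
For reference, INFERENCE_PRECISION_HINT in these test names corresponds to the runtime inference-precision hint. A minimal sketch of requesting f16 execution through the public API (model path is a placeholder):

#include <openvino/openvino.hpp>

int main() {
    ov::Core core;
    auto model = core.read_model("model.xml");  // placeholder path
    auto compiled = core.compile_model(
        model, "CPU", ov::hint::inference_precision(ov::element::f16));
    return 0;
}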
