Skip to content

Commit b76d4ca

Browse files
tczeszunvpirogov
authored andcommitted
cpu: x64: gemm: disable po for unsupported data types & ISAs
1 parent 190a9b2 commit b76d4ca

File tree

2 files changed

+13
-6
lines changed

2 files changed

+13
-6
lines changed

src/cpu/gemm_inner_product_utils.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -203,8 +203,8 @@ pp_kernel_t *pp_kernel_t::create(size_t OC, size_t MB, dim_t dst_mb_stride,
203203
bool post_ops_ok(const post_ops_t &post_ops, const memory_desc_wrapper *dst_d,
204204
const bcast_set_t &enabled_bcast_strategy) {
205205
#if DNNL_X64
206-
static constexpr auto isa_supported
207-
= x64::inner_product_utils::jit_pp_kernel_supported_isa();
206+
const auto isa_supported
207+
= x64::inner_product_utils::get_max_jit_pp_kernel_supported_isa();
208208
using namespace cpu::x64;
209209
if (mayiuse(isa_supported)) {
210210
using namespace x64::injector;
@@ -231,9 +231,8 @@ bool post_ops_ok(const post_ops_t &post_ops, const memory_desc_wrapper *dst_d,
231231
is_binary_po_per_mb_w_bcast, utils::one_of(ndims, 3, 4))
232232
&& IMPLICATION(
233233
is_binary_po_per_w_bcast, utils::one_of(ndims, 3, 4));
234-
const cpu_isa_t isa = get_max_cpu_isa();
235234
return supported_binary_bcast
236-
&& injector::post_ops_ok(post_ops_ok_args_t(isa,
235+
&& injector::post_ops_ok(post_ops_ok_args_t(isa_supported,
237236
{binary, eltwise, sum}, post_ops, dst_d,
238237
sum_at_pos_0_only, sum_requires_scale_one,
239238
sum_requires_zp_zero, sum_requires_same_params,

src/cpu/x64/jit_gemm_inner_product_utils.hpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2020-2021 Intel Corporation
2+
* Copyright 2020-2023 Intel Corporation
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -30,7 +30,15 @@ cpu::inner_product_utils::pp_kernel_t *jit_pp_kernel_create(size_t OC,
3030
data_type_t bias_dt, data_type_t acc_dt, const memory_desc_t *dst_md,
3131
bool skip_sum);
3232

33-
constexpr cpu_isa_t jit_pp_kernel_supported_isa() {
33+
inline cpu_isa_t get_max_jit_pp_kernel_supported_isa() {
34+
#define CASE(isa) \
35+
do { \
36+
if (mayiuse(isa)) return isa; \
37+
} while (false)
38+
CASE(avx512_core_bf16);
39+
CASE(avx512_core);
40+
CASE(avx2);
41+
#undef CASE
3442
return sse41;
3543
}
3644

0 commit comments

Comments
 (0)