Skip to content

Commit 2968c89

Browse files
xuxinzentprimak
authored andcommitted
cpu: x64: dispatch the shape that requires large/small cache to jit on avx2
1 parent 068f850 commit 2968c89

File tree

1 file changed

+42
-0
lines changed

1 file changed

+42
-0
lines changed

src/cpu/x64/jit_brgemm_conv_utils.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2189,6 +2189,27 @@ status_t init_conf(jit_brgemm_conv_conf_t &jcp, bool use_inversion,
21892189
!(jcp.ow_block == 0 || jcp.ic_block == 0 || jcp.oc_block == 0),
21902190
VERBOSE_BLOCKING_FAIL);
21912191

2192+
// Dispatch the shape that requires large or small cache to JIT
2193+
// for better performance on AVX2
2194+
// The threshold is empirical
2195+
const size_t w_cache_sz
2196+
= static_cast<size_t>(jcp.src_dsz) * jcp.ic_block * jcp.iwp
2197+
+ jcp.dst_dsz * jcp.ow * jcp.oc_block;
2198+
const size_t wei_cache_sz = static_cast<size_t>(jcp.wei_dsz) * jcp.kd_block
2199+
* jcp.kh_block * jcp.kw_block * jcp.ic_block * jcp.oc_block;
2200+
const size_t nthr_work_amount
2201+
= div_up(static_cast<size_t>(jcp.mb) * jcp.ngroups * jcp.nb_od
2202+
* jcp.nb_oh * jcp.nb_ow * jcp.nb_oc,
2203+
jcp.nthr);
2204+
const bool req_large_cache = jcp.oc >= 256 && jcp.ic >= 256
2205+
&& nstl::max(w_cache_sz, wei_cache_sz) * nthr_work_amount
2206+
>= brg_blocking_t::L2 * 10;
2207+
const bool req_small_cache = jcp.ic <= jcp.acc_simd_w
2208+
&& nstl::max(w_cache_sz, wei_cache_sz) <= 2048;
2209+
VDISPATCH_CONV_IC(
2210+
!((req_large_cache || req_small_cache) && jcp.isa == avx2),
2211+
"Dispatch the shape that requires large/small cache size to jit");
2212+
21922213
// to avoid cache concurrent write access from different threads
21932214
size_t sc_size = sizeof(brgemm_batch_element_t);
21942215
jcp.adjusted_batch_size
@@ -2403,6 +2424,27 @@ status_t init_1x1_conf(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa,
24032424
VDISPATCH_CONV_IC(
24042425
!(jcp.ic_block == 0 || jcp.oc_block == 0), VERBOSE_BLOCKING_FAIL);
24052426

2427+
// Dispatch the shape that requires large or small cache to JIT
2428+
// for better performance on AVX2
2429+
// The threshold is empirical
2430+
const size_t w_cache_sz
2431+
= static_cast<size_t>(jcp.src_dsz) * jcp.ic_block * jcp.iwp
2432+
+ jcp.dst_dsz * jcp.ow * jcp.oc_block;
2433+
const size_t wei_cache_sz = static_cast<size_t>(jcp.wei_dsz) * jcp.kd_block
2434+
* jcp.kh_block * jcp.kw_block * jcp.ic_block * jcp.oc_block;
2435+
const size_t nthr_work_amount
2436+
= div_up(static_cast<size_t>(jcp.mb) * jcp.ngroups * jcp.nb_od
2437+
* jcp.nb_oh * jcp.nb_ow * jcp.nb_oc,
2438+
jcp.nthr);
2439+
const bool req_large_cache = jcp.oc >= 256 && jcp.ic >= 256
2440+
&& nstl::max(w_cache_sz, wei_cache_sz) * nthr_work_amount
2441+
>= brg_blocking_t::L2 * 10;
2442+
const bool req_small_cache = jcp.ic <= jcp.acc_simd_w
2443+
&& nstl::max(w_cache_sz, wei_cache_sz) <= 2048;
2444+
VDISPATCH_CONV_IC(
2445+
!((req_large_cache || req_small_cache) && jcp.isa == avx2),
2446+
"Dispatch the shapes that requie large/small cache size to jit");
2447+
24062448
// Configure matrix sizes
24072449

24082450
if (best_brgb.is_os_blocking) {

0 commit comments

Comments
 (0)