@@ -2189,6 +2189,27 @@ status_t init_conf(jit_brgemm_conv_conf_t &jcp, bool use_inversion,
21892189 !(jcp.ow_block == 0 || jcp.ic_block == 0 || jcp.oc_block == 0 ),
21902190 VERBOSE_BLOCKING_FAIL);
21912191
2192+ // Dispatch the shape that requires large or small cache to JIT
2193+ // for better performance on AVX2
2194+ // The threshold is empirical
2195+ const size_t w_cache_sz
2196+ = static_cast <size_t >(jcp.src_dsz ) * jcp.ic_block * jcp.iwp
2197+ + jcp.dst_dsz * jcp.ow * jcp.oc_block ;
2198+ const size_t wei_cache_sz = static_cast <size_t >(jcp.wei_dsz ) * jcp.kd_block
2199+ * jcp.kh_block * jcp.kw_block * jcp.ic_block * jcp.oc_block ;
2200+ const size_t nthr_work_amount
2201+ = div_up (static_cast <size_t >(jcp.mb ) * jcp.ngroups * jcp.nb_od
2202+ * jcp.nb_oh * jcp.nb_ow * jcp.nb_oc ,
2203+ jcp.nthr );
2204+ const bool req_large_cache = jcp.oc >= 256 && jcp.ic >= 256
2205+ && nstl::max (w_cache_sz, wei_cache_sz) * nthr_work_amount
2206+ >= brg_blocking_t ::L2 * 10 ;
2207+ const bool req_small_cache = jcp.ic <= jcp.acc_simd_w
2208+ && nstl::max (w_cache_sz, wei_cache_sz) <= 2048 ;
2209+ VDISPATCH_CONV_IC (
2210+ !((req_large_cache || req_small_cache) && jcp.isa == avx2),
2211+ " Dispatch the shape that requires large/small cache size to jit" );
2212+
21922213 // to avoid cache concurrent write access from different threads
21932214 size_t sc_size = sizeof (brgemm_batch_element_t );
21942215 jcp.adjusted_batch_size
@@ -2403,6 +2424,27 @@ status_t init_1x1_conf(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa,
24032424 VDISPATCH_CONV_IC (
24042425 !(jcp.ic_block == 0 || jcp.oc_block == 0 ), VERBOSE_BLOCKING_FAIL);
24052426
2427+ // Dispatch the shape that requires large or small cache to JIT
2428+ // for better performance on AVX2
2429+ // The threshold is empirical
2430+ const size_t w_cache_sz
2431+ = static_cast <size_t >(jcp.src_dsz ) * jcp.ic_block * jcp.iwp
2432+ + jcp.dst_dsz * jcp.ow * jcp.oc_block ;
2433+ const size_t wei_cache_sz = static_cast <size_t >(jcp.wei_dsz ) * jcp.kd_block
2434+ * jcp.kh_block * jcp.kw_block * jcp.ic_block * jcp.oc_block ;
2435+ const size_t nthr_work_amount
2436+ = div_up (static_cast <size_t >(jcp.mb ) * jcp.ngroups * jcp.nb_od
2437+ * jcp.nb_oh * jcp.nb_ow * jcp.nb_oc ,
2438+ jcp.nthr );
2439+ const bool req_large_cache = jcp.oc >= 256 && jcp.ic >= 256
2440+ && nstl::max (w_cache_sz, wei_cache_sz) * nthr_work_amount
2441+ >= brg_blocking_t ::L2 * 10 ;
2442+ const bool req_small_cache = jcp.ic <= jcp.acc_simd_w
2443+ && nstl::max (w_cache_sz, wei_cache_sz) <= 2048 ;
2444+ VDISPATCH_CONV_IC (
2445+ !((req_large_cache || req_small_cache) && jcp.isa == avx2),
2446+ " Dispatch the shapes that require large/small cache size to jit" );
2447+
24062448 // Configure matrix sizes
24072449
24082450 if (best_brgb.is_os_blocking ) {
0 commit comments