@@ -871,6 +871,7 @@ struct brgemm_convolution_fwd_t<isa, use_inversion>::brgemm_thread_ctx_t {
871871 int32_t *src_zp_comp_ptr;
872872 int32_t *dst_zp_vals;
873873 int32_t *s8s8_comp_ptr;
874+ const float *dst_scales {nullptr };
874875};
875876
876877template <cpu_isa_t isa, bool use_inversion>
@@ -884,6 +885,7 @@ status_t brgemm_convolution_fwd_t<isa, use_inversion>::execute(
884885
885886 DEFINE_ARG_SCALES_BUFFER (src_scales, DNNL_ARG_SRC);
886887 DEFINE_ARG_SCALES_BUFFER (wei_scales, DNNL_ARG_WEIGHTS);
888+ DEFINE_ARG_SCALES_BUFFER (dst_scales, DNNL_ARG_DST);
887889
888890 const float *oscales = precompute_scales (ctx.get_scratchpad_grantor (),
889891 src_scales, wei_scales, _pd->OC (), _pd->attr ());
@@ -1012,6 +1014,7 @@ status_t brgemm_convolution_fwd_t<isa, use_inversion>::execute(
10121014 btc.src_zp_comp_ptr
10131015 = jcp.src_zero_point ? src_zp_comp_base : nullptr ;
10141016 btc.s8s8_comp_ptr = jcp.s8s8_avx512 ? s8s8_comp_base : nullptr ;
1017+ btc.dst_scales = dst_scales;
10151018
10161019 if (jcp.exec_type == exec_trans && (last_n != n || last_g != g)) {
10171020 if (!jcp.copy_block_only )
@@ -1133,7 +1136,7 @@ void brgemm_convolution_fwd_t<isa, use_inversion>::perform_outwork(
11331136 int kd_l, int kh_l, const void *post_ops_binary_rhs_arg_vec,
11341137 const float *oscales, int32_t src_zp_vals, int32_t *src_zp_ptr,
11351138 int32_t *dst_zp_ptr, int32_t *s8s8_compensation, bool maybe_do_init,
1136- bool do_postwork, bool do_post_comp) const {
1139+ bool do_postwork, bool do_post_comp, const float *dst_scales ) const {
11371140
11381141 const auto _pd = pd ();
11391142 const auto &jcp = _pd->jcp_ ;
@@ -1160,6 +1163,7 @@ void brgemm_convolution_fwd_t<isa, use_inversion>::perform_outwork(
11601163 p.dst_orig = dst;
11611164 p.c_zp_values = dst_zp_ptr;
11621165 p.a_comp_val = src_zp_vals;
1166+ p.ptr_dst_scales = (void *)dst_scales;
11631167 }
11641168
11651169 auto call_outwork_ker = [&](bool is_postwork, bool has_postcomp,
@@ -1246,7 +1250,7 @@ void brgemm_convolution_fwd_t<isa, use_inversion>::call_brgemm_kernel(
12461250 static_cast <size_t >(g_oc), 0 , btc.brgemm_ctx .dst , 0 ,
12471251 static_cast <void *>(src_zp_ptr), nullptr ,
12481252 static_cast <void *>(dst_zp_ptr), false , src_zp_vals,
1249- do_only_comp, do_only_pass_comp};
1253+ do_only_comp, do_only_pass_comp, btc. dst_scales };
12501254
12511255 void *scratch = is_amx ? static_cast <void *>(btc.wsp_tile )
12521256 : static_cast <void *>(s8s8_comp);
@@ -1581,7 +1585,7 @@ void brgemm_convolution_fwd_t<isa, use_inversion>::ker_base(
15811585 g_oc, is_oc_tail, ow_b, ow_e, kd_l, kh_l,
15821586 post_ops_binary_rhs_arg_vec.data (), btc.oscales ,
15831587 btc.src_zp_vals , btc.src_zp_comp_ptr , btc.dst_zp_vals ,
1584- btc.s8s8_comp_ptr , do_init, do_postwork, false );
1588+ btc.s8s8_comp_ptr , do_init, do_postwork, false , btc. dst_scales );
15851589 };
15861590
15871591 if (kd_f > kd_s && kh_f > kh_s && kw_f > kw_s) {
@@ -1636,7 +1640,7 @@ void brgemm_convolution_fwd_t<isa, use_inversion>::ker_base(
16361640 g_oc, is_oc_tail, ow, ow, kd_l, kh_l,
16371641 post_ops_binary_rhs_arg_vec.data (), btc.oscales ,
16381642 btc.src_zp_vals , btc.src_zp_comp_ptr , btc.dst_zp_vals ,
1639- btc.s8s8_comp_ptr , do_init, do_postwork, false );
1643+ btc.s8s8_comp_ptr , do_init, do_postwork, false , btc. dst_scales );
16401644 }
16411645}
16421646
@@ -1792,7 +1796,7 @@ void brgemm_convolution_fwd_t<isa, use_inversion>::ker_trans(
17921796 g_oc, is_oc_tail, ow, ow, kd_l, kh_l,
17931797 post_ops_binary_rhs_arg_vec.data (), btc.oscales ,
17941798 btc.src_zp_vals , btc.src_zp_comp_ptr , btc.dst_zp_vals ,
1795- btc.s8s8_comp_ptr , do_init, do_postwork, false );
1799+ btc.s8s8_comp_ptr , do_init, do_postwork, false , btc. dst_scales );
17961800 }
17971801}
17981802
@@ -1944,7 +1948,7 @@ void brgemm_convolution_fwd_t<isa, use_inversion>::ker_vpad(
19441948 g_oc, is_oc_tail, ow, ow, kd_l, kh_l,
19451949 post_ops_binary_rhs_arg_vec.data (), btc.oscales ,
19461950 btc.src_zp_vals , btc.src_zp_comp_ptr , btc.dst_zp_vals ,
1947- btc.s8s8_comp_ptr , do_init, do_postwork, false );
1951+ btc.s8s8_comp_ptr , do_init, do_postwork, false , btc. dst_scales );
19481952 }
19491953}
19501954
0 commit comments