@@ -136,7 +136,7 @@ status_t brgemm_convolution_bwd_weights_t::pd_t::init(engine_t *engine) {
136136 brgattr.max_top_vpad = 0 ;
137137 brgattr.max_bottom_vpad = 0 ;
138138
139- brgattr.LDA2 = jcp_.tr_iw * jcp_.ih * jcp_.id ;
139+ brgattr.LDA2 = jcp_.tr_iw * jcp_.ih_block * jcp_.id ;
140140 brgattr.LDB2 = jcp_.tr_ow * jcp_.oc_block * jcp_.oh * jcp_.od ;
141141 brgattr.LDC2_M = jcp_.oc_block * jcp_.kd * jcp_.kh * jcp_.kw ;
142142 brgattr.LDC2_N = jcp_.nb_ic * jcp_.ic_block * jcp_.oc_block
@@ -465,7 +465,7 @@ struct brgemm_convolution_bwd_weights_t::thread_info_t {
465465
466466 size_t tr_src_off (int g, int icb, int id, int ih) const {
467467 const size_t tr_row_size = jcp.tr_iw * jcp.ic_block ;
468- const size_t tr_3d_size = tr_row_size * jcp.ih ;
468+ const size_t tr_3d_size = tr_row_size * jcp.ih_block ;
469469 int adj = (jcp.global_transpose ) ? 1 : jcp.nb_ic_blocking ;
470470 // Aligned to buffer end to use guard elements
471471 return tr_src_buf_number (g, icb) * adj * jcp.tr_src_buf_size
@@ -1023,7 +1023,7 @@ void brgemm_convolution_bwd_weights_t::compute_diff_weights_3d(
10231023 + _pd->filter_w_to_src (kw) / jcp.stride_w
10241024 + (kw % jcp.stride_w ) * src_stride_w_shift
10251025 + (bs_ih_s - ih_s) * jcp.tr_iw * jcp.ic_block
1026- + (bs_id_s - id_s) * jcp.ih * jcp.tr_iw * jcp.ic_block ;
1026+ + (bs_id_s - id_s) * jcp.ih_block * jcp.tr_iw * jcp.ic_block ;
10271027 const void *ptr_B = ((diff_dst_data_t *)p_dst)
10281028 + (bs_oh_s - oh_s) * jcp.tr_ow * jcp.oc_block
10291029 + (bs_od_s - od_s) * jcp.oh * jcp.tr_ow * jcp.oc_block ;
@@ -1045,7 +1045,7 @@ void brgemm_convolution_bwd_weights_t::compute_diff_weights_3d(
10451045 ti->brg_batch [odb * bs_h + ohb].ptr .A = (char *)ptr_A
10461046 + ohb * jcp.typesize_in * jcp.tr_iw * jcp.ic_block
10471047 * jcp.stride_h
1048- + odb * jcp.typesize_in * jcp.ih * jcp.tr_iw
1048+ + odb * jcp.typesize_in * jcp.ih_block * jcp.tr_iw
10491049 * jcp.ic_block * jcp.stride_d ;
10501050 ti->brg_batch [odb * bs_h + ohb].ptr .B = (char *)ptr_B
10511051 + ohb * jcp.typesize_in * jcp.tr_ow * jcp.oc_block
0 commit comments