Skip to content

Commit 017950a

Browse files
committed
x64: brgemm bwd_w conv: use ih_block instead of ih for tr_src scratchpad
1 parent 796a600 commit 017950a

File tree

2 files changed

+17
-14
lines changed

2 files changed

+17
-14
lines changed

src/cpu/x64/jit_brgemm_conv_bwd_w.cpp

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -136,7 +136,7 @@ status_t brgemm_convolution_bwd_weights_t::pd_t::init(engine_t *engine) {
136136
brgattr.max_top_vpad = 0;
137137
brgattr.max_bottom_vpad = 0;
138138

139-
brgattr.LDA2 = jcp_.tr_iw * jcp_.ih * jcp_.id;
139+
brgattr.LDA2 = jcp_.tr_iw * jcp_.ih_block * jcp_.id;
140140
brgattr.LDB2 = jcp_.tr_ow * jcp_.oc_block * jcp_.oh * jcp_.od;
141141
brgattr.LDC2_M = jcp_.oc_block * jcp_.kd * jcp_.kh * jcp_.kw;
142142
brgattr.LDC2_N = jcp_.nb_ic * jcp_.ic_block * jcp_.oc_block
@@ -465,7 +465,7 @@ struct brgemm_convolution_bwd_weights_t::thread_info_t {
465465

466466
size_t tr_src_off(int g, int icb, int id, int ih) const {
467467
const size_t tr_row_size = jcp.tr_iw * jcp.ic_block;
468-
const size_t tr_3d_size = tr_row_size * jcp.ih;
468+
const size_t tr_3d_size = tr_row_size * jcp.ih_block;
469469
int adj = (jcp.global_transpose) ? 1 : jcp.nb_ic_blocking;
470470
// Aligned to buffer end to use guard elements
471471
return tr_src_buf_number(g, icb) * adj * jcp.tr_src_buf_size
@@ -1023,7 +1023,7 @@ void brgemm_convolution_bwd_weights_t::compute_diff_weights_3d(
10231023
+ _pd->filter_w_to_src(kw) / jcp.stride_w
10241024
+ (kw % jcp.stride_w) * src_stride_w_shift
10251025
+ (bs_ih_s - ih_s) * jcp.tr_iw * jcp.ic_block
1026-
+ (bs_id_s - id_s) * jcp.ih * jcp.tr_iw * jcp.ic_block;
1026+
+ (bs_id_s - id_s) * jcp.ih_block * jcp.tr_iw * jcp.ic_block;
10271027
const void *ptr_B = ((diff_dst_data_t *)p_dst)
10281028
+ (bs_oh_s - oh_s) * jcp.tr_ow * jcp.oc_block
10291029
+ (bs_od_s - od_s) * jcp.oh * jcp.tr_ow * jcp.oc_block;
@@ -1045,7 +1045,7 @@ void brgemm_convolution_bwd_weights_t::compute_diff_weights_3d(
10451045
ti->brg_batch[odb * bs_h + ohb].ptr.A = (char *)ptr_A
10461046
+ ohb * jcp.typesize_in * jcp.tr_iw * jcp.ic_block
10471047
* jcp.stride_h
1048-
+ odb * jcp.typesize_in * jcp.ih * jcp.tr_iw
1048+
+ odb * jcp.typesize_in * jcp.ih_block * jcp.tr_iw
10491049
* jcp.ic_block * jcp.stride_d;
10501050
ti->brg_batch[odb * bs_h + ohb].ptr.B = (char *)ptr_B
10511051
+ ohb * jcp.typesize_in * jcp.tr_ow * jcp.oc_block

src/cpu/x64/jit_brgemm_conv_utils.cpp

Lines changed: 13 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -2518,16 +2518,6 @@ void balance_bwd_w(jit_brgemm_conv_conf_t &jcp) {
25182518
jcp.nthr_g = nthr_g;
25192519
jcp.nthr_oc_b = nthr_oc_b;
25202520
jcp.nthr_ic_b = nthr_ic_b;
2521-
2522-
// TODO: Optimize memory allocation when threaded on height and depth
2523-
jcp.tr_src_buf_size = jcp.tr_iw * jcp.ic_block * jcp.ih * jcp.id;
2524-
jcp.tr_diff_dst_buf_size = jcp.tr_ow * jcp.oc_block * jcp.oh * jcp.od;
2525-
jcp.tr_src_buf_count = jcp.global_transpose
2526-
? jcp.nthr_mb * jcp.nb_ic * jcp.ngroups
2527-
: jcp.nthr;
2528-
jcp.tr_diff_dst_buf_count = jcp.global_transpose
2529-
? jcp.nthr_mb * jcp.nb_oc * jcp.ngroups
2530-
: jcp.nthr;
25312521
}
25322522

25332523
status_t init_conf_bwd_w(jit_brgemm_conv_conf_t &jcp,
@@ -2758,6 +2748,19 @@ status_t init_conf_bwd_w(jit_brgemm_conv_conf_t &jcp,
27582748
// try to split oh by equal oh blocks
27592749
oh_block_limit = div_up(jcp.oh, div_up(jcp.oh, oh_block_limit));
27602750
jcp.oh_block = utils::saturate(1, jcp.oh, oh_block_limit);
2751+
jcp.ih_block = nstl::min(jcp.ih,
2752+
jcp.stride_h
2753+
* brg_blocking_t::get_inp_size(jcp.ih, jcp.oh_block, jcp.kh,
2754+
jcp.stride_h, jcp.dilate_h));
2755+
// TODO: Optimize memory allocation when threaded on height and depth
2756+
jcp.tr_src_buf_count = jcp.global_transpose
2757+
? jcp.nthr_mb * jcp.nb_ic * jcp.ngroups
2758+
: jcp.nthr;
2759+
jcp.tr_diff_dst_buf_count = jcp.global_transpose
2760+
? jcp.nthr_mb * jcp.nb_oc * jcp.ngroups
2761+
: jcp.nthr;
2762+
jcp.tr_src_buf_size = jcp.tr_iw * jcp.ic_block * jcp.ih_block * jcp.id;
2763+
jcp.tr_diff_dst_buf_size = jcp.tr_ow * jcp.oc_block * jcp.oh * jcp.od;
27612764

27622765
const int iframe_size = irow_size * jcp.id;
27632766
const int oframe_size = orow_size * jcp.od;

0 commit comments

Comments
 (0)