
Commit 1eab005

tczeszuntprimak authored and committed
x64: brgconv: remove unnecessary batchsize loops
1 parent fbe5b97 commit 1eab005

2 files changed (+87, -112 lines)

src/cpu/x64/jit_brgemm_conv_bwd_strided.cpp

Lines changed: 84 additions & 109 deletions
@@ -130,18 +130,7 @@ status_t brgemm_convolution_bwd_strided_t<isa, is_deconv>::pd_t::init(
 
     const auto adj_M = nstl::max(jcp_.M, jcp_.M_tail);
 
-    batchsizes.resize(jcp_.max_batch + 1);
-    for (int i = 0; i <= jcp_.max_batch; i++)
-        batchsizes[i] = -1;
-
-    first_bs = 0;
-    bs_c = 0;
-
-    batchsizes[jcp_.max_batch] = bs_c;
-    first_bs = jcp_.max_batch;
-    bs_c++;
-
-    brgs_sz_ = bs_c * adj_M * 2 * 2 * 2;
+    brgs_sz_ = adj_M * 2 * 2 * 2;
     brgs_ = std::make_shared<brgemm_containers::brgemm_desc_container_t>();
     brgs_->resize(brgs_sz_);
 
@@ -159,86 +148,79 @@ status_t brgemm_convolution_bwd_strided_t<isa, is_deconv>::pd_t::init(
         if (one_of(jcp_.exec_type, exec_trans, exec_vpad) && vM != jcp_.M
                 && vM != jcp_.M_tail)
             continue;
-        for (int bs = 0; bs <= jcp_.max_batch; bs++) {
-            if (batchsizes[bs] == -1) continue;
-            for_(int i_init = 0; i_init < 2; i_init++)
-            for_(int i_N = 0; i_N < 2; i_N++)
-            for (int i_K = 0; i_K < 2; i_K++) {
-                auto vbeta = (i_init) ? 0 : beta;
-                auto vN = (i_N) ? jcp_.N_tail : jcp_.N;
-                auto vK = (i_K) ? jcp_.K_tail : jcp_.K;
-                auto vbrgM = jcp_.use_M_mask
-                        ? (vM == jcp_.M ? jcp_.brgM : jcp_.brgM_tail)
-                        : vM;
-                auto brg_idx = get_brg_idx(bs, i, i_init, i_N, i_K);
-                // if brgemm_t already created then skip this iteration
-                if ((*brgs_)[brg_idx] != nullptr) continue;
-                brgemm_t brg;
-                if (vN == 0 || vK == 0) continue;
-                brgemm_strides_t brg_strides;
-                brg_strides.stride_a = jcp_.brg_stride_a;
-                brg_strides.stride_b = jcp_.brg_stride_b;
-                brg.req_cal_comp_pads = jcp_.req_brg_comp_pad;
-                brg.req_comp_pads_with_bcast
-                        = jcp_.req_cal_comp_pad && jcp_.exec_type == exec_trans;
-                const auto strides_ptr = (jcp_.brg_type == brgemm_strd)
-                        ? &brg_strides
-                        : nullptr;
-                CHECK(brgemm_desc_init(&brg, isa, jcp_.brg_type, diff_dst_type,
-                        wei_type, false, false, brgemm_row_major, alpha, vbeta,
-                        jcp_.LDA, jcp_.LDB, jcp_.LDC, vbrgM, vN, vK,
-                        strides_ptr));
-
-                brgemm_attr_t brgattr;
-                brgattr.use_uker = jcp_.use_uker;
-                brgattr.use_interleave_stores = jcp_.use_interleave_stores;
-                brgattr.hint_prefetching = jcp_.hint_prefetching;
-                brgattr.max_bs = bs;
-                brgattr.hint_innermost_loop = jcp_.brgemm_bd_loop_innermost
-                        ? brgemm_bd_loop_innermost
-                        : brgemm_ld_loop_innermost;
-                if (jcp_.amx_tile_load_xx) {
-                    // assuming 2x2 decomposition in amx brgemm kernel
-                    // and overlap of input by kw
-                    const auto bd_blocking = 2 * jcp_.amx_h;
-                    const auto ld_blocking = 2 * 16;
-                    brgattr.hint_expected_A_size = bd_blocking * jcp_.K
-                            * jcp_.kd_block * jcp_.kh_block;
-                    brgattr.hint_expected_B_size = ld_blocking * jcp_.K
-                            * jcp_.kd_block * jcp_.kh_block * jcp_.kw_block;
-                    brgattr.hint_expected_C_size = bd_blocking * ld_blocking;
-                } else {
-                    brgattr.hint_expected_A_size = 0;
-                    brgattr.hint_expected_B_size = 0;
-                    brgattr.hint_expected_C_size = 0;
-                }
+        for_(int i_init = 0; i_init < 2; i_init++)
+        for_(int i_N = 0; i_N < 2; i_N++)
+        for (int i_K = 0; i_K < 2; i_K++) {
+            auto vbeta = (i_init) ? 0 : beta;
+            auto vN = (i_N) ? jcp_.N_tail : jcp_.N;
+            auto vK = (i_K) ? jcp_.K_tail : jcp_.K;
+            auto vbrgM = jcp_.use_M_mask
+                    ? (vM == jcp_.M ? jcp_.brgM : jcp_.brgM_tail)
+                    : vM;
+            auto brg_idx = get_brg_idx(jcp_.max_batch, i, i_init, i_N, i_K);
+            // if brgemm_t already created then skip this iteration
+            if ((*brgs_)[brg_idx] != nullptr) continue;
+            brgemm_t brg;
+            if (vN == 0 || vK == 0) continue;
+            brgemm_strides_t brg_strides;
+            brg_strides.stride_a = jcp_.brg_stride_a;
+            brg_strides.stride_b = jcp_.brg_stride_b;
+            brg.req_cal_comp_pads = jcp_.req_brg_comp_pad;
+            brg.req_comp_pads_with_bcast
+                    = jcp_.req_cal_comp_pad && jcp_.exec_type == exec_trans;
+            const auto strides_ptr
+                    = (jcp_.brg_type == brgemm_strd) ? &brg_strides : nullptr;
+            CHECK(brgemm_desc_init(&brg, isa, jcp_.brg_type, diff_dst_type,
+                    wei_type, false, false, brgemm_row_major, alpha, vbeta,
+                    jcp_.LDA, jcp_.LDB, jcp_.LDC, vbrgM, vN, vK, strides_ptr));
+
+            brgemm_attr_t brgattr;
+            brgattr.use_uker = jcp_.use_uker;
+            brgattr.use_interleave_stores = jcp_.use_interleave_stores;
+            brgattr.hint_prefetching = jcp_.hint_prefetching;
+            brgattr.max_bs = jcp_.max_batch;
+            brgattr.hint_innermost_loop = jcp_.brgemm_bd_loop_innermost
+                    ? brgemm_bd_loop_innermost
+                    : brgemm_ld_loop_innermost;
+            if (jcp_.amx_tile_load_xx) {
+                // assuming 2x2 decomposition in amx brgemm kernel
+                // and overlap of input by kw
+                const auto bd_blocking = 2 * jcp_.amx_h;
+                const auto ld_blocking = 2 * 16;
+                brgattr.hint_expected_A_size
+                        = bd_blocking * jcp_.K * jcp_.kd_block * jcp_.kh_block;
+                brgattr.hint_expected_B_size = ld_blocking * jcp_.K
+                        * jcp_.kd_block * jcp_.kh_block * jcp_.kw_block;
+                brgattr.hint_expected_C_size = bd_blocking * ld_blocking;
+            } else {
+                brgattr.hint_expected_A_size = 0;
+                brgattr.hint_expected_B_size = 0;
+                brgattr.hint_expected_C_size = 0;
+            }
 
-                brgattr.wary_tail_read = false;
-                // use_M_mask is always 0 for brgemm_convolution_bwd_strided_t
-                brgattr.bd_mask = nullptr;
-                brgattr.bd_mask_level = jcp_.use_M_mask;
-
-                if (is_amx) {
-                    brgattr.max_top_vpad = 0;
-                    brgattr.max_bottom_vpad = 0;
-                } else {
-                    brgattr.max_top_vpad = jcp_.max_vpad;
-                    brgattr.max_bottom_vpad = jcp_.max_vpad;
-                }
-                brgattr.generate_skip_accumulation = true;
-                CHECK(brgemm_desc_set_attr(&brg, brgattr));
-
-                auto LDD = jcp_.stride_w * jcp_.ic_without_padding;
-                brg.with_sum = with_sum;
-                brg.with_weights_scale_adjust
-                        = jcp_.scale_adjust_factor != 1.0f;
-                CHECK(brgemm_desc_set_postops(
-                        &brg, attr(), &diff_src_md_, LDD, jcp_.bia_dt));
-                jcp_.amx_buf_size_per_thread
-                        = nstl::max(brg.get_wsp_buffer_size(),
-                                jcp_.amx_buf_size_per_thread);
-                brgs_->insert(brg_idx, brg);
+            brgattr.wary_tail_read = false;
+            // use_M_mask is always 0 for brgemm_convolution_bwd_strided_t
+            brgattr.bd_mask = nullptr;
+            brgattr.bd_mask_level = jcp_.use_M_mask;
+
+            if (is_amx) {
+                brgattr.max_top_vpad = 0;
+                brgattr.max_bottom_vpad = 0;
+            } else {
+                brgattr.max_top_vpad = jcp_.max_vpad;
+                brgattr.max_bottom_vpad = jcp_.max_vpad;
             }
+            brgattr.generate_skip_accumulation = true;
+            CHECK(brgemm_desc_set_attr(&brg, brgattr));
+
+            auto LDD = jcp_.stride_w * jcp_.ic_without_padding;
+            brg.with_sum = with_sum;
+            brg.with_weights_scale_adjust = jcp_.scale_adjust_factor != 1.0f;
+            CHECK(brgemm_desc_set_postops(
+                    &brg, attr(), &diff_src_md_, LDD, jcp_.bia_dt));
+            jcp_.amx_buf_size_per_thread = nstl::max(
+                    brg.get_wsp_buffer_size(), jcp_.amx_buf_size_per_thread);
+            brgs_->insert(brg_idx, brg);
         }
     }
 
@@ -410,17 +392,13 @@ void brgemm_convolution_bwd_strided_t<isa, is_deconv>::create_kernels() {
             : 0;
     int i_init_end = 2;
 
-    for (int bs = 0; bs <= jcp.max_batch; bs++) {
-        if (_pd->batchsizes[bs] == -1) continue;
-
-        for_(int i_N = N_begin; i_N < N_end; i_N++)
-        for_(int i_M = M_begin; i_M < M_end; i_M++)
-        for_(int i_init = i_init_begin; i_init < i_init_end; i_init++)
-        for (int i_K = K_begin; i_K < K_end; i_K++) {
-            auto M = (i_M) ? jcp.M_tail : jcp.M;
-            if (M <= 0) continue;
-            add_brg_kernel(bs, M, i_N, i_K, i_init);
-        }
+    for_(int i_N = N_begin; i_N < N_end; i_N++)
+    for_(int i_M = M_begin; i_M < M_end; i_M++)
+    for_(int i_init = i_init_begin; i_init < i_init_end; i_init++)
+    for (int i_K = K_begin; i_K < K_end; i_K++) {
+        auto M = (i_M) ? jcp.M_tail : jcp.M;
+        if (M <= 0) continue;
+        add_brg_kernel(jcp.max_batch, M, i_N, i_K, i_init);
     }
 
     if (jcp.exec_type == exec_base) {
@@ -444,14 +422,11 @@ void brgemm_convolution_bwd_strided_t<isa, is_deconv>::create_kernels() {
         for (int kw = kw_s; kw < kw_f; kw++) {
             get_iw_range(iw_str, iw, kw, iw_s, M_without_overflow);
             if (M_without_overflow <= 0) continue;
-            for (int bs = 0; bs <= jcp.max_batch; bs++) {
-                if (_pd->batchsizes[bs] == -1) continue;
-                for_(int i_init = 0; i_init < 2; i_init++)
-                for_(int i_N = 0; i_N < 2; i_N++)
-                for (int i_K = 0; i_K < 2; i_K++) {
-                    add_brg_kernel(
-                            bs, M_without_overflow, i_N, i_K, i_init);
-                }
+            for_(int i_init = 0; i_init < 2; i_init++)
+            for_(int i_N = 0; i_N < 2; i_N++)
+            for (int i_K = 0; i_K < 2; i_K++) {
+                add_brg_kernel(jcp.max_batch, M_without_overflow, i_N, i_K,
+                        i_init);
             }
 
             bool is_iw_tail = (jcp.iw - iw < jcp.iw_block);

src/cpu/x64/jit_brgemm_conv_bwd_strided.hpp

Lines changed: 3 additions & 3 deletions
@@ -59,11 +59,11 @@ struct brgemm_convolution_bwd_strided_t : public primitive_t {
         std::shared_ptr<brgemm_containers::brgemm_desc_container_t> brgs_;
 
         jit_brgemm_conv_conf_t jcp_;
-        // batch sizes info for unrolled kernels
-        int bs_c, first_bs;
-        std::vector<int> batchsizes;
+        // batch size info
+        const int first_bs = 0;
         int get_brg_idx(int bs, int m, bool do_initialization, bool is_N_tail,
                 bool is_K_tail) const {
+            const int bs_c = 1;
             auto bs_idx = 0;
             return (((m * bs_c + bs_idx) * 2
                             + static_cast<int>(do_initialization))
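The header change makes the simplification visible in the index math: only a single batch size (jcp_.max_batch) was ever registered, so bs_c is now a constant 1 and bs_idx is always 0, the batch-size dimension of the descriptor table is degenerate, and brgs_sz_ shrinks from bs_c * adj_M * 2 * 2 * 2 to adj_M * 2 * 2 * 2. Below is a minimal standalone sketch of that indexing argument; the names and values are hypothetical and it is not oneDNN code, and the real get_brg_idx() also folds in the N/K tail flags, which the hunk above shows only partially.

    // Standalone illustration: dropping a degenerate batch-size dimension
    // from a flattened descriptor index does not change which slot is hit.
    #include <cassert>
    #include <cstdio>

    // Old-style index: a leading batch-size dimension of extent bs_c.
    static int idx_with_bs(int bs_idx, int bs_c, int m, int do_init) {
        assert(bs_idx >= 0 && bs_idx < bs_c);
        return (m * bs_c + bs_idx) * 2 + do_init;
    }

    // New-style index: the batch-size dimension is gone (bs_c == 1, bs_idx == 0).
    static int idx_without_bs(int m, int do_init) {
        return m * 2 + do_init;
    }

    int main() {
        const int adj_M = 4; // hypothetical stand-in for nstl::max(jcp.M, jcp.M_tail)
        // With only one batch size ever registered, both formulas enumerate
        // exactly the same adj_M * 2 slots.
        for (int m = 0; m < adj_M; m++)
            for (int do_init = 0; do_init < 2; do_init++)
                assert(idx_with_bs(/*bs_idx=*/0, /*bs_c=*/1, m, do_init)
                        == idx_without_bs(m, do_init));
        std::printf("descriptor slots: %d (previously bs_c * %d with bs_c == 1)\n",
                adj_M * 2, adj_M * 2);
        return 0;
    }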
