Skip to content

Commit 27845b8

Browse files
committed
x64: fix scales and bias order in brgemm conv post-ops kernel
test coverage extended
1 parent 8bb651c commit 27845b8

File tree

2 files changed

+15
-13
lines changed

2 files changed

+15
-13
lines changed

src/cpu/x64/jit_brgemm_post_ops.hpp

Lines changed: 12 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -685,18 +685,6 @@ struct jit_brgemm_kernel_post_ops : public jit_generator {
685685

686686
if (req_comp) maybe_apply_comp(m_block, n_block, tail);
687687

688-
if (brg.beta != 0 && jcp.with_bias) {
689-
for (int n = 0; n < n_block; n++) {
690-
auto vmm_bias = vmm_tmp(0);
691-
auto bias_addr = ptr[aux_reg_bias
692-
+ bia_typesize_ * (n * brg.ld_block)];
693-
cvt2ps(bia_dt_, vmm_bias, bias_addr, tail, false, k_mask);
694-
for (int m = 0; m < m_block; m++) {
695-
vaddps(vector(m, n), vmm_bias);
696-
}
697-
}
698-
}
699-
700688
if (brg.beta != 0) {
701689
for_(int m = 0; m < m_block; m++)
702690
for (int n = 0; n < n_block; n++) {
@@ -714,6 +702,18 @@ struct jit_brgemm_kernel_post_ops : public jit_generator {
714702
}
715703
}
716704

705+
if (brg.beta != 0 && jcp.with_bias) {
706+
for (int n = 0; n < n_block; n++) {
707+
auto vmm_bias = vmm_tmp(0);
708+
auto bias_addr = ptr[aux_reg_bias
709+
+ bia_typesize_ * (n * brg.ld_block)];
710+
cvt2ps(bia_dt_, vmm_bias, bias_addr, tail, false, k_mask);
711+
for (int m = 0; m < m_block; m++) {
712+
vaddps(vector(m, n), vmm_bias);
713+
}
714+
}
715+
}
716+
717717
if (postops_injector_) inject_attr_postops(m_block, n_block, tail);
718718

719719
if (brg.beta != 0 && brg.zp_type_c != brgemm_broadcast_t::none) {

tests/benchdnn/inputs/conv/harness_conv_attrs_int8

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -27,8 +27,10 @@
2727
--mb=2
2828
--skip-impl=ref,x64:gemm # ! test jit version only
2929
--dir=FWD_B
30-
--attr-scales=src:common:0.25*+wei:per_oc:0.5*+dst:common:2.25* --attr-post-ops=sum:1.5:2+relu
30+
--attr-scales=src:common:0.25*+wei:per_oc:0.5*,src:common:0.25*+wei:per_oc:0.5*+dst:common:2.25*
31+
--attr-post-ops=sum:1.5:2+relu
3132
--cfg=s8s8f32,s8s8u8,u8s8f32,u8s8u8 --batch=shapes_tails
33+
--cfg=s8s8u8,u8s8u8 --batch=shapes_basic
3234

3335
# i8 conv + f32 leaky relu
3436
--reset --dir=FWD_B --mb=2

0 commit comments

Comments
 (0)