Skip to content

Commit 38319f1

Browse files
akharitotprimak
authored andcommitted
x64: support dst scales in brgemm-based implementations
1 parent 18de927 commit 38319f1

17 files changed

+119
-47
lines changed

src/cpu/x64/jit_brdgmm_dw_conv.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2021-2022 Intel Corporation
2+
* Copyright 2021-2023 Intel Corporation
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -195,7 +195,8 @@ status_t brdgmm_dw_convolution_fwd_t::pd_t::init(engine_t *engine) {
195195
|| !wei_scales.has_default_values();
196196
jcp.is_oc_scale = wei_scales.mask_ != 0;
197197

198-
const bool scales_ok = attr_scales_ok({DNNL_ARG_SRC, DNNL_ARG_WEIGHTS});
198+
const bool scales_ok
199+
= attr_scales_ok({DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST});
199200
if (!scales_ok) return status::unimplemented;
200201

201202
// strd is only feasible for 1D (i.e., height dim is one)
@@ -388,6 +389,7 @@ status_t brdgmm_dw_convolution_fwd_t::execute(const exec_ctx_t &ctx) const {
388389

389390
DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
390391
DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
392+
DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);
391393

392394
const float *oscales = precompute_scales(ctx.get_scratchpad_grantor(),
393395
src_scales, wei_scales, pd()->OC(), pd()->attr());
@@ -516,6 +518,7 @@ status_t brdgmm_dw_convolution_fwd_t::execute(const exec_ctx_t &ctx) const {
516518
post_ops_data.bias = bias + ch * jcp.bia_dsz;
517519
post_ops_data.scales = &oscales[jcp.is_oc_scale * ch];
518520
post_ops_data.oc_logical_off = ch;
521+
post_ops_data.dst_scales = dst_scales;
519522
brgemm_kernel_execute_postops(kernel, bs, ptr_A, ptr_B,
520523
brg_batch, ptr_C, ptr_C, post_ops_data,
521524
nullptr /*scratch*/);

src/cpu/x64/jit_brgemm_1x1_conv.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,8 @@ void brgemm_1x1_convolution_fwd_t<isa>::exec_ker(
323323
char *const c_buffer, const char *inp_buffer, int g, int n, int ocb,
324324
int od, int oh, int ow, int icc, int *last_palette_idx,
325325
const float *oscales, int32_t src_zp_vals, int32_t *src_zp_comp,
326-
int32_t *dst_zp_vals, int32_t *s8s8_compensation) const {
326+
int32_t *dst_zp_vals, int32_t *s8s8_compensation,
327+
const float *dst_scales) const {
327328

328329
const memory_desc_wrapper src_d(pd()->src_md());
329330
const memory_desc_wrapper weights_d(pd()->weights_md());
@@ -428,7 +429,8 @@ void brgemm_1x1_convolution_fwd_t<isa>::exec_ker(
428429
post_ops_binary_rhs_arg_vec.data(),
429430
static_cast<size_t>(g_oc), 0, dst, 0,
430431
static_cast<void *>(src_zp_comp_ptr), nullptr,
431-
static_cast<void *>(dst_zp_vals), false, src_zp_vals};
432+
static_cast<void *>(dst_zp_vals), false, src_zp_vals, false,
433+
false, dst_scales};
432434

433435
void *scratch = is_amx ? static_cast<void *>(wsp_tile)
434436
: static_cast<void *>(s8s8_comp_ptr);
@@ -473,6 +475,7 @@ status_t brgemm_1x1_convolution_fwd_t<isa>::execute_forward_all(
473475

474476
DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
475477
DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
478+
DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);
476479

477480
const float *oscales = precompute_scales(ctx.get_scratchpad_grantor(),
478481
src_scales, wei_scales, pd()->OC(), pd()->attr());
@@ -553,7 +556,8 @@ status_t brgemm_1x1_convolution_fwd_t<isa>::execute_forward_all(
553556
exec_ker(brgemm_ctx, ithr, brg_batch, c_buffer, \
554557
inp_buffer_sp, g, n, ocb, od, oh, ow, icc, \
555558
&last_palette_idx, oscales, src_zero_point, \
556-
zp_compensation, dst_zp_vals, s8s8_compensation); \
559+
zp_compensation, dst_zp_vals, s8s8_compensation, \
560+
dst_scales); \
557561
} \
558562
} \
559563
last_n = n; \
@@ -595,7 +599,7 @@ status_t brgemm_1x1_convolution_fwd_t<isa>::execute_forward_all(
595599
exec_ker(brgemm_ctx, ithr, brg_batch, c_buffer, nullptr, g, n, \
596600
ocb, od, oh, ow, icc, &last_palette_idx, oscales, \
597601
src_zero_point, zp_compensation, dst_zp_vals, \
598-
s8s8_compensation); \
602+
s8s8_compensation, dst_scales); \
599603
} \
600604
nd_iterator_step(__VA_ARGS__); \
601605
} \

src/cpu/x64/jit_brgemm_1x1_conv.hpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2021-2022 Intel Corporation
2+
* Copyright 2021-2023 Intel Corporation
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -64,7 +64,8 @@ struct brgemm_1x1_convolution_fwd_t : public primitive_t {
6464

6565
protected:
6666
bool arg_scales_ok() const {
67-
std::vector<int> supported_args = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS};
67+
std::vector<int> supported_args
68+
= {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST};
6869
return attr_scales_ok(supported_args);
6970
}
7071
bool zero_points_ok() const {
@@ -121,7 +122,8 @@ struct brgemm_1x1_convolution_fwd_t : public primitive_t {
121122
char *const c_buffer, const char *inp_buffer, int g, int n, int ocb,
122123
int od, int oh, int ow, int icc, int *last_brg_idx,
123124
const float *oscales, int32_t src_zp_vals, int32_t *src_zp_comp,
124-
int32_t *dst_zp_vals, int32_t *s8s8_compensation) const;
125+
int32_t *dst_zp_vals, int32_t *s8s8_compensation,
126+
const float *dst_scales) const;
125127
status_t execute_forward_all(const exec_ctx_t &ctx) const;
126128
const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
127129

src/cpu/x64/jit_brgemm_conv.cpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -871,6 +871,7 @@ struct brgemm_convolution_fwd_t<isa, use_inversion>::brgemm_thread_ctx_t {
871871
int32_t *src_zp_comp_ptr;
872872
int32_t *dst_zp_vals;
873873
int32_t *s8s8_comp_ptr;
874+
const float *dst_scales {nullptr};
874875
};
875876

876877
template <cpu_isa_t isa, bool use_inversion>
@@ -884,6 +885,7 @@ status_t brgemm_convolution_fwd_t<isa, use_inversion>::execute(
884885

885886
DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
886887
DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
888+
DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);
887889

888890
const float *oscales = precompute_scales(ctx.get_scratchpad_grantor(),
889891
src_scales, wei_scales, _pd->OC(), _pd->attr());
@@ -1012,6 +1014,7 @@ status_t brgemm_convolution_fwd_t<isa, use_inversion>::execute(
10121014
btc.src_zp_comp_ptr
10131015
= jcp.src_zero_point ? src_zp_comp_base : nullptr;
10141016
btc.s8s8_comp_ptr = jcp.s8s8_avx512 ? s8s8_comp_base : nullptr;
1017+
btc.dst_scales = dst_scales;
10151018

10161019
if (jcp.exec_type == exec_trans && (last_n != n || last_g != g)) {
10171020
if (!jcp.copy_block_only)
@@ -1133,7 +1136,7 @@ void brgemm_convolution_fwd_t<isa, use_inversion>::perform_outwork(
11331136
int kd_l, int kh_l, const void *post_ops_binary_rhs_arg_vec,
11341137
const float *oscales, int32_t src_zp_vals, int32_t *src_zp_ptr,
11351138
int32_t *dst_zp_ptr, int32_t *s8s8_compensation, bool maybe_do_init,
1136-
bool do_postwork, bool do_post_comp) const {
1139+
bool do_postwork, bool do_post_comp, const float *dst_scales) const {
11371140

11381141
const auto _pd = pd();
11391142
const auto &jcp = _pd->jcp_;
@@ -1160,6 +1163,7 @@ void brgemm_convolution_fwd_t<isa, use_inversion>::perform_outwork(
11601163
p.dst_orig = dst;
11611164
p.c_zp_values = dst_zp_ptr;
11621165
p.a_comp_val = src_zp_vals;
1166+
p.ptr_dst_scales = (void *)dst_scales;
11631167
}
11641168

11651169
auto call_outwork_ker = [&](bool is_postwork, bool has_postcomp,
@@ -1246,7 +1250,7 @@ void brgemm_convolution_fwd_t<isa, use_inversion>::call_brgemm_kernel(
12461250
static_cast<size_t>(g_oc), 0, btc.brgemm_ctx.dst, 0,
12471251
static_cast<void *>(src_zp_ptr), nullptr,
12481252
static_cast<void *>(dst_zp_ptr), false, src_zp_vals,
1249-
do_only_comp, do_only_pass_comp};
1253+
do_only_comp, do_only_pass_comp, btc.dst_scales};
12501254

12511255
void *scratch = is_amx ? static_cast<void *>(btc.wsp_tile)
12521256
: static_cast<void *>(s8s8_comp);
@@ -1581,7 +1585,7 @@ void brgemm_convolution_fwd_t<isa, use_inversion>::ker_base(
15811585
g_oc, is_oc_tail, ow_b, ow_e, kd_l, kh_l,
15821586
post_ops_binary_rhs_arg_vec.data(), btc.oscales,
15831587
btc.src_zp_vals, btc.src_zp_comp_ptr, btc.dst_zp_vals,
1584-
btc.s8s8_comp_ptr, do_init, do_postwork, false);
1588+
btc.s8s8_comp_ptr, do_init, do_postwork, false, btc.dst_scales);
15851589
};
15861590

15871591
if (kd_f > kd_s && kh_f > kh_s && kw_f > kw_s) {
@@ -1636,7 +1640,7 @@ void brgemm_convolution_fwd_t<isa, use_inversion>::ker_base(
16361640
g_oc, is_oc_tail, ow, ow, kd_l, kh_l,
16371641
post_ops_binary_rhs_arg_vec.data(), btc.oscales,
16381642
btc.src_zp_vals, btc.src_zp_comp_ptr, btc.dst_zp_vals,
1639-
btc.s8s8_comp_ptr, do_init, do_postwork, false);
1643+
btc.s8s8_comp_ptr, do_init, do_postwork, false, btc.dst_scales);
16401644
}
16411645
}
16421646

@@ -1792,7 +1796,7 @@ void brgemm_convolution_fwd_t<isa, use_inversion>::ker_trans(
17921796
g_oc, is_oc_tail, ow, ow, kd_l, kh_l,
17931797
post_ops_binary_rhs_arg_vec.data(), btc.oscales,
17941798
btc.src_zp_vals, btc.src_zp_comp_ptr, btc.dst_zp_vals,
1795-
btc.s8s8_comp_ptr, do_init, do_postwork, false);
1799+
btc.s8s8_comp_ptr, do_init, do_postwork, false, btc.dst_scales);
17961800
}
17971801
}
17981802

@@ -1944,7 +1948,7 @@ void brgemm_convolution_fwd_t<isa, use_inversion>::ker_vpad(
19441948
g_oc, is_oc_tail, ow, ow, kd_l, kh_l,
19451949
post_ops_binary_rhs_arg_vec.data(), btc.oscales,
19461950
btc.src_zp_vals, btc.src_zp_comp_ptr, btc.dst_zp_vals,
1947-
btc.s8s8_comp_ptr, do_init, do_postwork, false);
1951+
btc.s8s8_comp_ptr, do_init, do_postwork, false, btc.dst_scales);
19481952
}
19491953
}
19501954

src/cpu/x64/jit_brgemm_conv.hpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2021-2022 Intel Corporation
2+
* Copyright 2021-2023 Intel Corporation
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -105,7 +105,8 @@ struct brgemm_convolution_fwd_t : public primitive_t {
105105

106106
protected:
107107
bool arg_scales_ok() const {
108-
std::vector<int> supported_args = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS};
108+
std::vector<int> supported_args
109+
= {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST};
109110
return attr_scales_ok(supported_args);
110111
}
111112

@@ -181,7 +182,7 @@ struct brgemm_convolution_fwd_t : public primitive_t {
181182
const void *post_ops_binary_rhs_arg_vec, const float *oscales,
182183
int32_t src_zp_vals, int32_t *src_zp_ptr, int32_t *dst_zp_ptr,
183184
int32_t *s8s8_compensation, bool maybe_do_init, bool do_postwork,
184-
bool do_post_comp) const;
185+
bool do_post_comp, const float *dst_scales) const;
185186

186187
void call_brgemm_kernel(brgemm_thread_ctx_t &btc, int brg_idx,
187188
int batch_size, char *ptr_C, char *ptr_D, const char *bias_w,

src/cpu/x64/jit_brgemm_conv_bwd_strided.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2022 Intel Corporation
2+
* Copyright 2022-2023 Intel Corporation
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -20,6 +20,7 @@
2020
#include "common/type_helpers.hpp"
2121
#include "common/utils.hpp"
2222
#include "cpu/cpu_primitive.hpp"
23+
#include "cpu/scale_utils.hpp"
2324

2425
#include "cpu/x64/jit_brgemm_conv_bwd_strided.hpp"
2526
#include "cpu/x64/jit_brgemm_conv_bwd_utils.hpp"
@@ -369,8 +370,12 @@ status_t brgemm_convolution_bwd_strided_t<isa, enable_postops>::execute(
369370
const auto _pd = pd();
370371
const auto &jcp = _pd->jcp_;
371372

372-
// XXX: brgemm requires scales to be passed, so passing default wei scales
373-
DEFINE_ARG_SCALES_BUFFER(oscales, DNNL_ARG_WEIGHTS);
373+
DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
374+
DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
375+
DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);
376+
377+
const float *oscales = precompute_scales(ctx.get_scratchpad_grantor(),
378+
src_scales, wei_scales, _pd->IC(), _pd->attr());
374379

375380
const memory_tracking::grantor_t scratchpad = ctx.get_scratchpad_grantor();
376381
brgemm_batch_element_t *const __restrict brg_batch_global
@@ -451,6 +456,7 @@ status_t brgemm_convolution_bwd_strided_t<isa, enable_postops>::execute(
451456
btc.ihb = ihb;
452457
btc.iwb = iwb;
453458
btc.oscales = oscales;
459+
btc.dst_scales = dst_scales;
454460

455461
auto id_begin = idb * jcp.id_block;
456462
auto id_end = nstl::min(ID, id_begin + jcp.id_block);
@@ -526,7 +532,7 @@ void brgemm_convolution_bwd_strided_t<isa, enable_postops>::call_brgemm_kernel(
526532
static_cast<size_t>(g_ic), 0, btc.brgemm_ctx.dst, 0,
527533
static_cast<void *>(src_zp_ptr), nullptr,
528534
static_cast<void *>(dst_zp_ptr), do_skip_accm, src_zp_vals,
529-
do_only_comp, do_only_pass_comp};
535+
do_only_comp, do_only_pass_comp, btc.dst_scales};
530536

531537
void *scratch = is_amx ? static_cast<void *>(btc.wsp_tile)
532538
: static_cast<void *>(s8s8_comp);

src/cpu/x64/jit_brgemm_conv_bwd_strided.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2022 Intel Corporation
2+
* Copyright 2022-2023 Intel Corporation
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -145,6 +145,7 @@ struct brgemm_convolution_bwd_strided_t : public primitive_t {
145145
int occ;
146146
int sw;
147147
const float *oscales {nullptr};
148+
const float *dst_scales {nullptr};
148149
};
149150

150151
void ker_trans(brgemm_bwd_thread_ctx_t &btc, char *inp_buffer) const;

src/cpu/x64/jit_brgemm_conv_utils.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2021-2022 Intel Corporation
2+
* Copyright 2021-2023 Intel Corporation
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.

src/cpu/x64/jit_brgemm_inner_product.cpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2020-2022 Intel Corporation
2+
* Copyright 2020-2023 Intel Corporation
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -78,6 +78,7 @@ status_t brgemm_inner_product_fwd_t<isa>::execute_forward(
7878

7979
DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
8080
DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
81+
DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);
8182

8283
const float *oscales = precompute_scales(ctx.get_scratchpad_grantor(),
8384
src_scales, wei_scales, pd()->OC(), pd()->attr());
@@ -108,7 +109,8 @@ status_t brgemm_inner_product_fwd_t<isa>::execute_forward(
108109

109110
const bool are_post_ops_applicable = one_of(true, jbgp.with_sum,
110111
jbgp.with_bias, jbgp.with_scales, jbgp.with_eltwise,
111-
jbgp.with_binary, jbgp.acc_dt != jbgp.dst_dt, jbgp.signed_input);
112+
jbgp.with_binary, jbgp.acc_dt != jbgp.dst_dt, jbgp.signed_input,
113+
jbgp.with_dst_scales);
112114

113115
size_t offset = types::data_type_size(jbgp.wei_dt)
114116
* (weights_d.size() - weights_d.additional_buffer_size());
@@ -221,7 +223,8 @@ status_t brgemm_inner_product_fwd_t<isa>::execute_forward(
221223
static_cast<const void *>(ptr_bias),
222224
&oscales[jbgp.is_oc_scale * oc],
223225
post_ops_binary_rhs_arg_vec.data(),
224-
static_cast<size_t>(oc), 0, dst};
226+
static_cast<size_t>(oc), 0, dst, 0, nullptr, nullptr,
227+
nullptr, false, 1, false, false, dst_scales};
225228

226229
brgemm_kernel_execute_postops(brg_kernel, gemm_batch,
227230
addr_batch, (void *)ptr_C, (void *)ptr_D, post_ops_data,
@@ -264,7 +267,8 @@ status_t brgemm_inner_product_fwd_t<isa>::execute_forward(
264267
static_cast<const void *>(ptr_bias),
265268
&oscales[jbgp.is_oc_scale * oc],
266269
post_ops_binary_rhs_arg_vec.data(),
267-
static_cast<size_t>(oc), 0, dst};
270+
static_cast<size_t>(oc), 0, dst, 0, nullptr, nullptr,
271+
nullptr, false, 1, false, false, dst_scales};
268272

269273
brgemm_kernel_execute_postops(brg_kernel_ic_tail, 1, addr_batch,
270274
(void *)ptr_C, (void *)ptr_D, post_ops_data, scratch);
@@ -457,7 +461,9 @@ status_t brgemm_inner_product_fwd_t<isa>::execute_forward(
457461
&oscales[jbgp.is_oc_scale * oc],
458462
post_ops_binary_rhs_arg_vec.data(),
459463
static_cast<size_t>(oc), 0, dst, 0, nullptr,
460-
nullptr, nullptr, true /* skip_accm */};
464+
nullptr, nullptr, true /* skip_accm */, 1,
465+
false, false, dst_scales};
466+
461467
brgemm_kernel_execute_postops(brg_kernel, 0,
462468
nullptr, (void *)ptr_C, (void *)ptr_D,
463469
post_ops_data, scratch);

src/cpu/x64/jit_brgemm_inner_product.hpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2020-2022 Intel Corporation
2+
* Copyright 2020-2023 Intel Corporation
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -81,7 +81,7 @@ struct brgemm_inner_product_fwd_t : public primitive_t {
8181
bool are_post_ops_applicable = one_of(true, jbgp_.with_sum,
8282
jbgp_.with_bias, jbgp_.with_scales, jbgp_.with_eltwise,
8383
jbgp_.with_binary, jbgp_.acc_dt != jbgp_.dst_dt,
84-
jbgp_.signed_input);
84+
jbgp_.signed_input, jbgp_.with_dst_scales);
8585

8686
const float alpha = 1.0;
8787
const float beta = 1.0;
@@ -142,7 +142,8 @@ struct brgemm_inner_product_fwd_t : public primitive_t {
142142
}
143143

144144
bool arg_scales_ok() const {
145-
std::vector<int> supported_args = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS};
145+
std::vector<int> supported_args
146+
= {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST};
146147
return attr_scales_ok(supported_args);
147148
}
148149

0 commit comments

Comments
 (0)