Skip to content

Commit 668abae

Browse files
kealan-barbieri authored and karturov committed
gpu: jit: reorder: enable any float to hf8, fixup
1 parent c3972ef commit 668abae

File tree

4 files changed

+67
-26
lines changed

4 files changed

+67
-26
lines changed

src/gpu/jit/codegen/reorder.hpp

Lines changed: 48 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -259,6 +259,7 @@ void emit_reorder_1d_tile(ngen::HW hw, GeneratorT *host,
259259
int dst_type_size = ngen::getBytes(dst_type);
260260
int src_stride_bytes = src_stride * src_type_size;
261261
int dst_stride_bytes = dst_stride * dst_type_size;
262+
int max_type_size = std::max(src_type_size, dst_type_size);
262263
bool dst_b = ngen_is_b(dst_type);
263264
bool dst_d = ngen_is_dw(dst_type);
264265
bool dst_q = ngen_is_qw(dst_type);
@@ -408,12 +409,22 @@ void emit_reorder_1d_tile(ngen::HW hw, GeneratorT *host,
408409
return;
409410
}
410411
// hf8 -> f16
411-
if (src_hf8 && dst_hf) {
412+
if (src_hf8) {
412413
int step = get_step();
413414
const int src_stride_bytes = src_stride;
414415
const int dst_stride_bytes = 2 * dst_stride;
415416
const int step_nregs
416417
= utils::div_up(step * ((int)sizeof(ngen::half)), grf_size);
418+
const bool do_post_reorder = !dst_hf;
419+
const int nregs = utils::div_up(width
420+
* std::max((int)sizeof(ngen::half), max_type_size)
421+
* std::max(src_stride, dst_stride),
422+
grf_size);
423+
if (do_post_reorder) {
424+
auto tmp_dst = lex_scope.alloc_reg_buf_data(nregs).format(
425+
0, ngen::DataType::hf);
426+
dst = std::move(tmp_dst);
427+
}
417428
auto tmp1 = lex_scope.alloc_reg_buf_data(step_nregs);
418429
auto tmp2 = lex_scope.alloc_reg_buf_data(step_nregs);
419430
for (int i = 0; i < width; i += step) {
@@ -451,30 +462,54 @@ void emit_reorder_1d_tile(ngen::HW hw, GeneratorT *host,
451462
host->mov(esize, d.reinterpret(0, ngen::DataType::uw)(dst_stride),
452463
tmp2.subregister(0, ngen::DataType::uw)(dst_stride));
453464
}
465+
if (do_post_reorder) {
466+
emit_reorder_1d_tile(
467+
hw, host, scope, width, dst, dst_stride, _dst, dst_stride);
468+
}
454469
return;
455470
}
456471

457-
if (src_hf && dst_hf8) {
472+
if (dst_hf8) {
458473
int step = get_step();
459474
const int src_stride_bytes = 2 * src_stride;
460475
const int dst_stride_bytes = dst_stride;
461476
const int step_nregs
462477
= utils::div_up(step * ((int)sizeof(ngen::half)), grf_size);
463478
auto tmp1 = lex_scope.alloc_reg_buf_data(step_nregs);
464479
auto tmp2 = lex_scope.alloc_reg_buf_data(step_nregs);
480+
const bool do_pre_reorder = !src_hf;
481+
const int nregs = utils::div_up(width
482+
483+
* std::max((int)sizeof(ngen::half), max_type_size)
484+
* std::max(src_stride, dst_stride),
485+
grf_size);
486+
if (do_pre_reorder) {
487+
auto tmp_src = lex_scope.alloc_reg_buf_data(nregs).format(
488+
0, ngen::DataType::hf);
489+
emit_reorder_1d_tile(hw, host, scope, width, src, src_stride,
490+
tmp_src, src_stride);
491+
src = std::move(tmp_src);
492+
}
465493
for (int i = 0; i < width; i += step) {
466494
step = std::min(step, width - i);
467495
step = utils::rnd_down_pow2(step);
468496
int esize = step;
469497

470498
auto s = src.subregister(i, esize, src_stride_bytes);
471499
auto d = dst.subregister(i, esize, dst_stride_bytes);
472-
473-
host->mov(esize, tmp1.subregister(0, ngen::DataType::uw)(1),
474-
s.reinterpret(0, ngen::DataType::uw)(src_stride));
500+
if (src_stride > 1 && s.getByteOffset() > 1) {
501+
host->mov(esize,
502+
tmp1.subregister(0, ngen::DataType::uw)(src_stride),
503+
s.reinterpret(0, ngen::DataType::uw)(src_stride));
504+
host->mov(esize, tmp1.subregister(0, ngen::DataType::uw)(1),
505+
tmp1.subregister(0, ngen::DataType::uw)(src_stride));
506+
} else {
507+
host->mov(esize, tmp1.subregister(0, ngen::DataType::uw)(1),
508+
s.reinterpret(0, ngen::DataType::uw)(src_stride));
509+
}
475510
// get sign bits
476-
host->and_(esize | host->nz | host->f1[1], host->null.uw(),
477-
s.reinterpret(0, ngen::DataType::uw)(1), 0x8000);
511+
host->and_(esize | host->nz | host->f2[0], host->null.uw(),
512+
tmp1.subregister(0, ngen::DataType::uw)(1), 0x8000);
478513
// multiply by hf 128 to force overflow of exponent
479514
host->mul(esize, tmp1.subregister(0, ngen::DataType::hf)(1),
480515
tmp1.subregister(0, ngen::DataType::hf)(1),
@@ -487,22 +522,21 @@ void emit_reorder_1d_tile(ngen::HW hw, GeneratorT *host,
487522
// check for NaN, inf.
488523
host->and_(esize | host->ze | host->f0[0], host->null.uw(),
489524
~tmp1.subregister(0, ngen::DataType::uw)(1), 0x7C00);
490-
// check for zero mantissa.
491-
host->and_(esize | host->ze | host->f1[0], host->null.uw(),
492-
tmp1.subregister(0, ngen::DataType::uw)(1), 0x7F);
493525
// round.
494-
host->add(esize | host->f1[0],
495-
tmp1.subregister(0, ngen::DataType::uw)(1),
526+
host->add(esize, tmp1.subregister(0, ngen::DataType::uw)(1),
496527
tmp1.subregister(0, ngen::DataType::uw)(1), -0x40);
528+
// check for zero mantissa.
529+
host->and_(esize | host->nz | host->f1[0], host->null.uw(),
530+
tmp1.subregister(0, ngen::DataType::uw)(1), 0x3FF);
497531
host->eshr(esize, tmp1.subregister(0, ngen::DataType::uw)(1),
498-
tmp1.subregister(0, ngen::DataType::uw)(src_stride), 7);
532+
tmp1.subregister(0, ngen::DataType::uw)(1), 7);
499533
host->add(esize | host->f1[0],
500534
tmp1.subregister(0, ngen::DataType::uw)(1),
501535
tmp1.subregister(0, ngen::DataType::uw)(1), 1);
502536
host->mov(esize | host->f0[0],
503537
tmp1.subregister(0, ngen::DataType::uw)(1), 0x7F);
504538
// handle sign.
505-
host->or_(esize | host->f1[1],
539+
host->or_(esize | host->f2[0],
506540
tmp1.subregister(0, ngen::DataType::uw)(1),
507541
tmp1.subregister(0, ngen::DataType::uw)(1), 0x80);
508542

@@ -519,7 +553,6 @@ void emit_reorder_1d_tile(ngen::HW hw, GeneratorT *host,
519553
// x <-> bf8
520554
if (src_bf8 || dst_bf8) {
521555
int step = get_step();
522-
int max_type_size = std::max(src_type_size, dst_type_size);
523556
ngen::DataType src_raw
524557
= src_bf8 ? ngen::DataType::ub : ngen::DataType::w;
525558
ngen::DataType dst_raw

src/gpu/jit/gemm/gen_gemm.hpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -139,7 +139,7 @@ struct gen_gemm_t : public gpu_gemm_t {
139139
ok = ok && d->b_type() == bf16
140140
&& utils::one_of(d->c_type(), bf16, f32)
141141
&& utils::one_of(d->acc_type, bf16, f32);
142-
} else if (!wei_decomp_) {
142+
} else if (!wei_decomp) {
143143
ok = ok
144144
&& utils::one_of(
145145
d->a_type(), f32, f16, f8_e5m2, f8_e4m3)

src/gpu/jit/gemm/gen_gemm_kernel_generator.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -25470,7 +25470,7 @@ bool gemm_kernel_generator_t<hw>::copyRegisters(Type Ts, Type Td,
2547025470

2547125471
const int nphases = 2, qCXMin = -1, qCXMax = -1;
2547225472

25473-
Subregister saveF0;
25473+
Subregister saveF0, saveF1, saveF2;
2547425474
bool releaseEmuFlag = false;
2547525475
bool preswizzle = (hw >= HW::XeHP);
2547625476
GRFRange copyTemp;

src/gpu/jit/reorder/gen_reorder.cpp

Lines changed: 17 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -43,6 +43,7 @@ status_t gen_reorder_t::pd_t::init(
4343
auto *compute_engine = utils::downcast<compute::compute_engine_t *>(engine);
4444
auto *device_info = compute_engine->device_info();
4545
zero_points_config_t zp_cfg(this);
46+
using namespace data_type;
4647

4748
auto post_ops_ok = [&]() {
4849
const auto &po = attr()->post_ops_;
@@ -65,30 +66,37 @@ status_t gen_reorder_t::pd_t::init(
6566
&& (!zp_cfg.do_dst_compensation
6667
|| zp_cfg.is_common_dst_zero_point);
6768
};
68-
auto is_bf16_or_f32_or_bf8 = [](data_type_t dt) {
69-
return utils::one_of(dt, data_type::bf16, data_type::f32,
70-
data_type::f8_e5m2, data_type::f8_e4m3);
69+
auto is_bf16_or_f32_or_f8 = [](data_type_t dt) {
70+
return utils::one_of(dt, bf16, f32, f8_e5m2, f8_e4m3);
71+
};
72+
auto hf8_ok = [&]() {
73+
bool any_hf8 = utils::one_of(f8_e4m3, dst_dt, src_dt);
74+
return IMPLICATION(any_hf8,
75+
utils::everyone_is(f8_e4m3, dst_dt, src_dt)
76+
|| utils::one_of(src_dt, bf16, f16, f32)
77+
|| utils::one_of(dst_dt, bf16, f16, f32));
7178
};
72-
bool any_hf8 = utils::one_of(data_type::f8_e4m3, dst_dt, src_dt);
7379
auto skip_mask = dnnl_primitive_attr::skip_mask_t::post_ops
7480
| dnnl_primitive_attr::skip_mask_t::zero_points_runtime
7581
| dnnl_primitive_attr::skip_mask_t::scales_runtime;
7682
using namespace data_type;
7783
bool ok = src_engine == dst_engine && src_engine->kind() == engine_kind::gpu
78-
&& utils::one_of(src_dt, f32, f16, bf16, f8_e5m2, s32, s8, u8, f64)
79-
&& utils::one_of(dst_dt, f32, f16, bf16, f8_e5m2, s32, s8, u8, f64)
84+
&& utils::one_of(
85+
src_dt, f32, f16, bf16, f8_e5m2, f8_e4m3, s32, s8, u8, f64)
86+
&& utils::one_of(
87+
dst_dt, f32, f16, bf16, f8_e5m2, f8_e4m3, s32, s8, u8, f64)
8088
&& IMPLICATION(src_dt == data_type::f16 || dst_dt == data_type::f16,
8189
device_info->has_native(data_type::f16))
8290
&& IMPLICATION(
83-
src_dt == data_type::bf16, is_bf16_or_f32_or_bf8(dst_dt))
91+
src_dt == data_type::bf16, is_bf16_or_f32_or_f8(dst_dt))
8492
&& IMPLICATION(
85-
dst_dt == data_type::bf16, is_bf16_or_f32_or_bf8(src_dt))
93+
dst_dt == data_type::bf16, is_bf16_or_f32_or_f8(src_dt))
8694
&& IMPLICATION(utils::one_of(data_type::f8_e5m2, src_dt, dst_dt),
8795
device_info->has_native(data_type::f8_e5m2))
8896
&& IMPLICATION(src_dt == data_type::f64 || dst_dt == data_type::f64,
8997
device_info->has_native(data_type::f64))
9098
&& attr()->has_default_values(skip_mask) && extra_ok()
91-
&& post_ops_ok() && scales_ok() && zps_ok();
99+
&& post_ops_ok() && scales_ok() && zps_ok() && hf8_ok();
92100
if (!ok) return status::unimplemented;
93101

94102
memory_desc_wrapper src_mdw {src_md()};

0 commit comments

Comments
 (0)