@@ -259,6 +259,7 @@ void emit_reorder_1d_tile(ngen::HW hw, GeneratorT *host,
259259 int dst_type_size = ngen::getBytes (dst_type);
260260 int src_stride_bytes = src_stride * src_type_size;
261261 int dst_stride_bytes = dst_stride * dst_type_size;
262+ int max_type_size = std::max (src_type_size, dst_type_size);
262263 bool dst_b = ngen_is_b (dst_type);
263264 bool dst_d = ngen_is_dw (dst_type);
264265 bool dst_q = ngen_is_qw (dst_type);
@@ -408,12 +409,22 @@ void emit_reorder_1d_tile(ngen::HW hw, GeneratorT *host,
408409 return ;
409410 }
410411 // hf8 -> f16
411- if (src_hf8 && dst_hf ) {
412+ if (src_hf8) {
412413 int step = get_step ();
413414 const int src_stride_bytes = src_stride;
414415 const int dst_stride_bytes = 2 * dst_stride;
415416 const int step_nregs
416417 = utils::div_up (step * ((int )sizeof (ngen::half)), grf_size);
418+ const bool do_post_reorder = !dst_hf;
419+ const int nregs = utils::div_up (width
420+ * std::max ((int )sizeof (ngen::half), max_type_size)
421+ * std::max (src_stride, dst_stride),
422+ grf_size);
423+ if (do_post_reorder) {
424+ auto tmp_dst = lex_scope.alloc_reg_buf_data (nregs).format (
425+ 0 , ngen::DataType::hf);
426+ dst = std::move (tmp_dst);
427+ }
417428 auto tmp1 = lex_scope.alloc_reg_buf_data (step_nregs);
418429 auto tmp2 = lex_scope.alloc_reg_buf_data (step_nregs);
419430 for (int i = 0 ; i < width; i += step) {
@@ -451,30 +462,54 @@ void emit_reorder_1d_tile(ngen::HW hw, GeneratorT *host,
451462 host->mov (esize, d.reinterpret (0 , ngen::DataType::uw)(dst_stride),
452463 tmp2.subregister (0 , ngen::DataType::uw)(dst_stride));
453464 }
465+ if (do_post_reorder) {
466+ emit_reorder_1d_tile (
467+ hw, host, scope, width, dst, dst_stride, _dst, dst_stride);
468+ }
454469 return ;
455470 }
456471
457- if (src_hf && dst_hf8) {
472+ if (dst_hf8) {
458473 int step = get_step ();
459474 const int src_stride_bytes = 2 * src_stride;
460475 const int dst_stride_bytes = dst_stride;
461476 const int step_nregs
462477 = utils::div_up (step * ((int )sizeof (ngen::half)), grf_size);
463478 auto tmp1 = lex_scope.alloc_reg_buf_data (step_nregs);
464479 auto tmp2 = lex_scope.alloc_reg_buf_data (step_nregs);
480+ const bool do_pre_reorder = !src_hf;
481+ const int nregs = utils::div_up (width
482+
483+ * std::max ((int )sizeof (ngen::half), max_type_size)
484+ * std::max (src_stride, dst_stride),
485+ grf_size);
486+ if (do_pre_reorder) {
487+ auto tmp_src = lex_scope.alloc_reg_buf_data (nregs).format (
488+ 0 , ngen::DataType::hf);
489+ emit_reorder_1d_tile (hw, host, scope, width, src, src_stride,
490+ tmp_src, src_stride);
491+ src = std::move (tmp_src);
492+ }
465493 for (int i = 0 ; i < width; i += step) {
466494 step = std::min (step, width - i);
467495 step = utils::rnd_down_pow2 (step);
468496 int esize = step;
469497
470498 auto s = src.subregister (i, esize, src_stride_bytes);
471499 auto d = dst.subregister (i, esize, dst_stride_bytes);
472-
473- host->mov (esize, tmp1.subregister (0 , ngen::DataType::uw)(1 ),
474- s.reinterpret (0 , ngen::DataType::uw)(src_stride));
500+ if (src_stride > 1 && s.getByteOffset () > 1 ) {
501+ host->mov (esize,
502+ tmp1.subregister (0 , ngen::DataType::uw)(src_stride),
503+ s.reinterpret (0 , ngen::DataType::uw)(src_stride));
504+ host->mov (esize, tmp1.subregister (0 , ngen::DataType::uw)(1 ),
505+ tmp1.subregister (0 , ngen::DataType::uw)(src_stride));
506+ } else {
507+ host->mov (esize, tmp1.subregister (0 , ngen::DataType::uw)(1 ),
508+ s.reinterpret (0 , ngen::DataType::uw)(src_stride));
509+ }
475510 // get sign bits
476- host->and_ (esize | host->nz | host->f1 [ 1 ], host->null .uw (),
477- s. reinterpret (0 , ngen::DataType::uw)(1 ), 0x8000 );
511+ host->and_ (esize | host->nz | host->f2 [ 0 ], host->null .uw (),
512+ tmp1. subregister (0 , ngen::DataType::uw)(1 ), 0x8000 );
478513 // multiply by hf 128 to force overflow of exponent
479514 host->mul (esize, tmp1.subregister (0 , ngen::DataType::hf)(1 ),
480515 tmp1.subregister (0 , ngen::DataType::hf)(1 ),
@@ -487,22 +522,21 @@ void emit_reorder_1d_tile(ngen::HW hw, GeneratorT *host,
487522 // check for NaN, inf.
488523 host->and_ (esize | host->ze | host->f0 [0 ], host->null .uw (),
489524 ~tmp1.subregister (0 , ngen::DataType::uw)(1 ), 0x7C00 );
490- // check for zero mantissa.
491- host->and_ (esize | host->ze | host->f1 [0 ], host->null .uw (),
492- tmp1.subregister (0 , ngen::DataType::uw)(1 ), 0x7F );
493525 // round.
494- host->add (esize | host->f1 [0 ],
495- tmp1.subregister (0 , ngen::DataType::uw)(1 ),
526+ host->add (esize, tmp1.subregister (0 , ngen::DataType::uw)(1 ),
496527 tmp1.subregister (0 , ngen::DataType::uw)(1 ), -0x40 );
528+ // check for zero mantissa.
529+ host->and_ (esize | host->nz | host->f1 [0 ], host->null .uw (),
530+ tmp1.subregister (0 , ngen::DataType::uw)(1 ), 0x3FF );
497531 host->eshr (esize, tmp1.subregister (0 , ngen::DataType::uw)(1 ),
498- tmp1.subregister (0 , ngen::DataType::uw)(src_stride ), 7 );
532+ tmp1.subregister (0 , ngen::DataType::uw)(1 ), 7 );
499533 host->add (esize | host->f1 [0 ],
500534 tmp1.subregister (0 , ngen::DataType::uw)(1 ),
501535 tmp1.subregister (0 , ngen::DataType::uw)(1 ), 1 );
502536 host->mov (esize | host->f0 [0 ],
503537 tmp1.subregister (0 , ngen::DataType::uw)(1 ), 0x7F );
504538 // handle sign.
505- host->or_ (esize | host->f1 [ 1 ],
539+ host->or_ (esize | host->f2 [ 0 ],
506540 tmp1.subregister (0 , ngen::DataType::uw)(1 ),
507541 tmp1.subregister (0 , ngen::DataType::uw)(1 ), 0x80 );
508542
@@ -519,7 +553,6 @@ void emit_reorder_1d_tile(ngen::HW hw, GeneratorT *host,
519553 // x <-> bf8
520554 if (src_bf8 || dst_bf8) {
521555 int step = get_step ();
522- int max_type_size = std::max (src_type_size, dst_type_size);
523556 ngen::DataType src_raw
524557 = src_bf8 ? ngen::DataType::ub : ngen::DataType::w;
525558 ngen::DataType dst_raw
0 commit comments