Skip to content

Commit 5587f08

Browse files
committed
gpu: jit: gemm: only QW-align widths for QW-aligned data
1 parent f5ff0a6 commit 5587f08

File tree

1 file changed

+19
-14
lines changed

1 file changed

+19
-14
lines changed

src/gpu/jit/gemm/gen_gemm_kernel_generator.cpp

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5770,9 +5770,12 @@ static inline bool canRelAddr(const RegisterBlock &blockSrc,
57705770
}
57715771

57725772
static inline int block2DWidthAlignment(Type T, const RegisterBlock &block,
5773+
const MatrixAddressing &atype,
57735774
const MatrixAddressingStrategy &astrategy) {
57745775
// Block 2D width must be DW-aligned, but generally use QW alignment for better performance for reads.
5775-
return ((astrategy.noExtraPad || block.writable) ? 4 : 8);
5776+
return ((astrategy.noExtraPad || block.writable || atype.alignment % 8)
5777+
? 4
5778+
: 8);
57765779
}
57775780

57785781
// Output code for setting up address/header GRFs for a single block, given
@@ -5989,7 +5992,7 @@ void gemm_kernel_generator_t<hw>::setupAddr(Type T, const GRFRange &addr,
59895992
if (doBaseAdjust && !astrategy.address2D) stub();
59905993
Subregister baStorage, baseAdjust, baseAdjustElems;
59915994

5992-
int widthAlign = block2DWidthAlignment(T, block, astrategy);
5995+
int widthAlign = block2DWidthAlignment(T, block, atype, astrategy);
59935996

59945997
if (!astrategy.address2D) mov(4, addr[0].ud(4)(1), 0u);
59955998

@@ -6729,6 +6732,7 @@ void gemm_kernel_generator_t<hw>::remaskLayout(Type T, int index, bool column,
67296732
}
67306733

67316734
static bool needsRemask(Type T, bool column, const RegisterBlock &block,
6735+
const MatrixAddressing &atype,
67326736
const MatrixAddressingStrategy &astrategy, bool ignoreMasks = false) {
67336737
if (!ignoreMasks)
67346738
if (column ? !block.remainderC : !block.remainderR) return false;
@@ -6740,19 +6744,20 @@ static bool needsRemask(Type T, bool column, const RegisterBlock &block,
67406744
int maskGranularity = block.ebytes;
67416745
if (block.ebytes >= 16) maskGranularity = 4;
67426746
if (block2DRemask)
6743-
maskGranularity = std::max(
6744-
maskGranularity, block2DWidthAlignment(T, block, astrategy));
6747+
maskGranularity = std::max(maskGranularity,
6748+
block2DWidthAlignment(T, block, atype, astrategy));
67456749
if (ignoreMasks && !(block2DRemask && astrategy.address2D))
67466750
maskGranularity = 256;
67476751

67486752
return (T.size() < maskGranularity);
67496753
}
67506754

67516755
static bool needsRemask(Type T, bool column,
6752-
const vector<RegisterBlock> &layout,
6756+
const vector<RegisterBlock> &layout, const MatrixAddressing &atype,
67536757
const MatrixAddressingStrategy &astrategy, bool ignoreMasks = false) {
67546758
for (auto &block : layout)
6755-
if (needsRemask(T, column, block, astrategy, ignoreMasks)) return true;
6759+
if (needsRemask(T, column, block, atype, astrategy, ignoreMasks))
6760+
return true;
67566761
return false;
67576762
}
67586763

@@ -13613,11 +13618,11 @@ void gemm_kernel_generator_t<hw>::kLoopActivateSLMRemainder(bool active,
1361313618
bool asIfMaskedAi = Ai_lateKRem && state.Ai_strategy.padded;
1361413619
bool asIfMaskedBi = Bi_lateKRem && state.Bi_strategy.padded;
1361513620
slmRemaskA = slmA && mayAccessAllK && !Ai_remIncrCopy
13616-
&& needsRemask(Ta_ext, true, state.Ai_layoutRem, state.Ai_strategy,
13617-
asIfMaskedAi);
13621+
&& needsRemask(Ta_ext, true, state.Ai_layoutRem, state.Ai,
13622+
state.Ai_strategy, asIfMaskedAi);
1361813623
slmRemaskB = slmB && mayAccessAllK && !Bi_remIncrCopy
13619-
&& needsRemask(Tb_ext, false, state.Bi_layoutRem, state.Bi_strategy,
13620-
asIfMaskedBi);
13624+
&& needsRemask(Tb_ext, false, state.Bi_layoutRem, state.Bi,
13625+
state.Bi_strategy, asIfMaskedBi);
1362113626
}
1362213627

1362313628
static inline void kLoopModifiedFlagAP(GEMMState &state) {
@@ -14376,11 +14381,11 @@ void gemm_kernel_generator_t<hw>::kLoop(KLoop type, const GEMMProblem &problem,
1437614381

1437714382
// A/B remasking in k dimension, during remainder handling.
1437814383
bool remaskA = !slmA && readA && (minOPCount > 1)
14379-
&& needsRemask(Ta_load, true, state.A_layoutRem, strategy.A,
14380-
state.A_lateKRem);
14384+
&& needsRemask(Ta_load, true, state.A_layoutRem, problem.A,
14385+
strategy.A, state.A_lateKRem);
1438114386
bool remaskB = !slmB && readB && (minOPCount > 1)
14382-
&& needsRemask(Tb_load, false, state.B_layoutRem, strategy.B,
14383-
state.B_lateKRem);
14387+
&& needsRemask(Tb_load, false, state.B_layoutRem, problem.B,
14388+
strategy.B, state.B_lateKRem);
1438414389

1438514390
if (Ta.isInteger() && Tb.isInteger() && !calcASums && !calcBSums) {
1438614391
// Only need to remask one operand for integer A/B. Choose the smaller one.

0 commit comments

Comments
 (0)