@@ -5847,9 +5847,12 @@ static inline bool canRelAddr(const RegisterBlock &blockSrc,
58475847}
58485848
// block2DWidthAlignment: chooses the width alignment value used for block 2D
// accesses to the given register block.
//
// Returns 4 when any of the following holds, and 8 otherwise:
//   - astrategy.noExtraPad is set,
//   - block.writable is set,
//   - atype.alignment is not a multiple of 8 (condition added by this change).
// Before this change only the first two conditions were checked, so 8 could be
// selected even when the matrix's declared alignment is not a multiple of 8.
//
// Parameters:
//   T         - element type; not referenced in the body shown here.
//   block     - register block being accessed; only `writable` is read.
//   atype     - matrix addressing description; only `alignment` is read
//               (new parameter introduced by this change).
//   astrategy - addressing strategy; only `noExtraPad` is read.
//
// NOTE(review): the 4 / 8 values presumably are byte counts (DW / QW, per the
// comment inside the body) -- confirm against the callers.
58495849static inline int block2DWidthAlignment(Type T, const RegisterBlock &block,
5850+ const MatrixAddressing &atype,
58505851 const MatrixAddressingStrategy &astrategy) {
58515852 // Block 2D width must be DW-aligned, but generally use QW alignment for better performance for reads.
5852- return ((astrategy.noExtraPad || block.writable) ? 4 : 8);
5853+ return ((astrategy.noExtraPad || block.writable || atype.alignment % 8)
5854+ ? 4
5855+ : 8);
58535856}
58545857
58555858static inline int block2DBaseAlignment(HW hw, int stepping) {
@@ -6071,7 +6074,7 @@ void gemm_kernel_generator_t<hw>::setupAddr(Type T, const GRFRange &addr,
60716074 if (doBaseAdjust && !astrategy.address2D) stub();
60726075 Subregister baStorage, baseAdjust, baseAdjustElems;
60736076
6074- int widthAlign = block2DWidthAlignment(T, block, astrategy);
6077+ int widthAlign = block2DWidthAlignment(T, block, atype, astrategy);
60756078
60766079 if (!astrategy.address2D) mov(4, addr[0].ud(4)(1), 0u);
60776080
@@ -6828,6 +6831,7 @@ void gemm_kernel_generator_t<hw>::remaskLayout(Type T, int index, bool column,
68286831}
68296832
// needsRemask (single-block overload): reports whether one register block
// requires remasking in the row or column direction.
//
// Visible logic:
//   1. Unless ignoreMasks is set, return false immediately when the block has
//      no remainder flag for the requested direction (remainderC when `column`
//      is true, remainderR otherwise).
//   2. Start the mask granularity at block.ebytes, replaced by 4 when
//      block.ebytes >= 16.
//   3. When block2DRemask is set, raise the granularity to at least
//      block2DWidthAlignment(T, block, atype, astrategy).
//   4. When ignoreMasks is set and the access is not (block2DRemask with
//      astrategy.address2D), force the granularity to 256.
//   5. Return true when T.paddedSize() is smaller than the granularity.
//
// This change threads the new `atype` parameter through to
// block2DWidthAlignment; the other steps are unchanged context lines.
//
// NOTE(review): the lines between the two hunks (including the definition of
// block2DRemask) are not part of this diff view -- consult the full source
// before relying on this summary.
68306833static bool needsRemask(Type T, bool column, const RegisterBlock &block,
6834+ const MatrixAddressing &atype,
68316835 const MatrixAddressingStrategy &astrategy, bool ignoreMasks = false) {
68326836 if (!ignoreMasks)
68336837 if (column ? !block.remainderC : !block.remainderR) return false;
@@ -6839,19 +6843,20 @@ static bool needsRemask(Type T, bool column, const RegisterBlock &block,
68396843 int maskGranularity = block.ebytes;
68406844 if (block.ebytes >= 16) maskGranularity = 4;
68416845 if (block2DRemask)
6842- maskGranularity = std::max(
6843- maskGranularity, block2DWidthAlignment(T, block, astrategy));
6846+ maskGranularity = std::max(maskGranularity,
6847+ block2DWidthAlignment(T, block, atype , astrategy));
68446848 if (ignoreMasks && !(block2DRemask && astrategy.address2D))
68456849 maskGranularity = 256;
68466850
68476851 return (T.paddedSize() < maskGranularity);
68486852}
68496853
// needsRemask (layout overload): returns true as soon as any block in
// `layout` needs remasking according to the single-block overload, and false
// when none does (including when `layout` is empty).
//
// All arguments are forwarded unchanged for each block. This change adds the
// `atype` parameter (inserted between `layout` and `astrategy`) so it can be
// passed on to the single-block overload.
68506854static bool needsRemask(Type T, bool column,
6851- const vector<RegisterBlock> &layout,
6855+ const vector<RegisterBlock> &layout, const MatrixAddressing &atype,
68526856 const MatrixAddressingStrategy &astrategy, bool ignoreMasks = false) {
68536857 for (auto &block : layout)
6854- if (needsRemask(T, column, block, astrategy, ignoreMasks)) return true;
6858+ if (needsRemask(T, column, block, atype, astrategy, ignoreMasks))
6859+ return true;
68556860 return false;
68566861}
68576862
@@ -14483,11 +14488,11 @@ void gemm_kernel_generator_t<hw>::kLoopActivateSLMRemainder(bool active,
1448314488 bool asIfMaskedAi = Ai_lateKRem && state.Ai_strategy.padded;
1448414489 bool asIfMaskedBi = Bi_lateKRem && state.Bi_strategy.padded;
1448514490 slmRemaskA = slmA && mayAccessAllK && !Ai_remIncrCopy
14486- && needsRemask(Ta_ext, true, state.Ai_layoutRem, state.Ai_strategy ,
14487- asIfMaskedAi);
14491+ && needsRemask(Ta_ext, true, state.Ai_layoutRem, state.Ai ,
14492+ state.Ai_strategy, asIfMaskedAi);
1448814493 slmRemaskB = slmB && mayAccessAllK && !Bi_remIncrCopy
14489- && needsRemask(Tb_ext, false, state.Bi_layoutRem, state.Bi_strategy ,
14490- asIfMaskedBi);
14494+ && needsRemask(Tb_ext, false, state.Bi_layoutRem, state.Bi ,
14495+ state.Bi_strategy, asIfMaskedBi);
1449114496}
1449214497
1449314498static inline void kLoopModifiedFlagAP(GEMMState &state) {
@@ -15341,11 +15346,11 @@ void gemm_kernel_generator_t<hw>::kLoop(KLoop type, const GEMMProblem &problem,
1534115346
1534215347 // A/B remasking in k dimension, during remainder handling.
1534315348 bool remaskA = !slmA && readA && (minOPCount > 1)
15344- && needsRemask(Ta_load, true, state.A_layoutRem, strategy .A,
15345- state.A_lateKRem);
15349+ && needsRemask(Ta_load, true, state.A_layoutRem, problem .A,
15350+ strategy.A, state.A_lateKRem);
1534615351 bool remaskB = !slmB && readB && (minOPCount > 1)
15347- && needsRemask(Tb_load, false, state.B_layoutRem, strategy .B,
15348- state.B_lateKRem);
15352+ && needsRemask(Tb_load, false, state.B_layoutRem, problem .B,
15353+ strategy.B, state.B_lateKRem);
1534915354
1535015355 if (Ta.isInteger() && Tb.isInteger() && !calcASums && !calcBSums) {
1535115356 // Only need to remask one operand for integer A/B. Choose the smaller one.
0 commit comments