-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[AArch64][NEON][SVE] Lower mixed sign/zero extended partial reductions to usdot #107566
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
ac88857
[AArch64][NEON][SVE] Lower mixed sign/zero extended partial reductions
SamTebbs33 b118e92
Move ext opcode checks back
SamTebbs33 eb62a5d
Add USDOT ISD node
SamTebbs33 15117fe
Match in ISelLowering
SamTebbs33 7268fe1
Improve tests
SamTebbs33 9dedb73
Simplify BIsZext check
SamTebbs33 15988da
Simplify ext opcode check
SamTebbs33 bb7c880
Fix usdot lane matching
SamTebbs33 79bb429
Format
SamTebbs33 0da208b
Fix BaseSIMDSUDOTIndex pattern
SamTebbs33 9fa2ae7
else if -> else
SamTebbs33 cad2936
Re-jig the run lines again
SamTebbs33 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2701,6 +2701,7 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const { | |
MAKE_CASE(AArch64ISD::SADDLV) | ||
MAKE_CASE(AArch64ISD::SDOT) | ||
MAKE_CASE(AArch64ISD::UDOT) | ||
MAKE_CASE(AArch64ISD::USDOT) | ||
MAKE_CASE(AArch64ISD::SMINV) | ||
MAKE_CASE(AArch64ISD::UMINV) | ||
MAKE_CASE(AArch64ISD::SMAXV) | ||
|
@@ -6114,6 +6115,11 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, | |
return DAG.getNode(Opcode, dl, Op.getValueType(), Op.getOperand(1), | ||
Op.getOperand(2), Op.getOperand(3)); | ||
} | ||
case Intrinsic::aarch64_neon_usdot: | ||
case Intrinsic::aarch64_sve_usdot: { | ||
return DAG.getNode(AArch64ISD::USDOT, dl, Op.getValueType(), | ||
Op.getOperand(1), Op.getOperand(2), Op.getOperand(3)); | ||
} | ||
case Intrinsic::get_active_lane_mask: { | ||
SDValue ID = | ||
DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo, dl, MVT::i64); | ||
|
@@ -21824,37 +21830,50 @@ SDValue tryLowerPartialReductionToDot(SDNode *N, | |
|
||
auto ExtA = MulOp->getOperand(0); | ||
auto ExtB = MulOp->getOperand(1); | ||
bool IsSExt = ExtA->getOpcode() == ISD::SIGN_EXTEND; | ||
bool IsZExt = ExtA->getOpcode() == ISD::ZERO_EXTEND; | ||
if (ExtA->getOpcode() != ExtB->getOpcode() || (!IsSExt && !IsZExt)) | ||
|
||
if (!ISD::isExtOpcode(ExtA->getOpcode()) || | ||
!ISD::isExtOpcode(ExtB->getOpcode())) | ||
return SDValue(); | ||
bool AIsSigned = ExtA->getOpcode() == ISD::SIGN_EXTEND; | ||
bool BIsSigned = ExtB->getOpcode() == ISD::SIGN_EXTEND; | ||
|
||
auto A = ExtA->getOperand(0); | ||
auto B = ExtB->getOperand(0); | ||
if (A.getValueType() != B.getValueType()) | ||
return SDValue(); | ||
|
||
unsigned Opcode = 0; | ||
|
||
if (IsSExt) | ||
Opcode = AArch64ISD::SDOT; | ||
else if (IsZExt) | ||
Opcode = AArch64ISD::UDOT; | ||
|
||
assert(Opcode != 0 && "Unexpected dot product case encountered."); | ||
|
||
EVT ReducedType = N->getValueType(0); | ||
EVT MulSrcType = A.getValueType(); | ||
|
||
// Dot products operate on chunks of four elements so there must be four times | ||
// as many elements in the wide type | ||
if ((ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8) || | ||
(ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16) || | ||
(ReducedType == MVT::v4i32 && MulSrcType == MVT::v16i8) || | ||
(ReducedType == MVT::v2i32 && MulSrcType == MVT::v8i8)) | ||
return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B); | ||
if (!(ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8) && | ||
!(ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16) && | ||
!(ReducedType == MVT::v4i32 && MulSrcType == MVT::v16i8) && | ||
!(ReducedType == MVT::v2i32 && MulSrcType == MVT::v8i8)) | ||
return SDValue(); | ||
|
||
return SDValue(); | ||
// If the extensions are mixed, we should lower it to a usdot instead | ||
unsigned Opcode = 0; | ||
if (AIsSigned != BIsSigned) { | ||
if (!Subtarget->hasMatMulInt8()) | ||
return SDValue(); | ||
|
||
bool Scalable = N->getValueType(0).isScalableVT(); | ||
// There's no nxv2i64 version of usdot | ||
if (Scalable && ReducedType != MVT::nxv4i32) | ||
return SDValue(); | ||
|
||
Opcode = AArch64ISD::USDOT; | ||
// USDOT expects the signed operand to be last | ||
if (!BIsSigned) | ||
std::swap(A, B); | ||
} else if (AIsSigned) | ||
Opcode = AArch64ISD::SDOT; | ||
else | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should be just `else`. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's very true, it also lets us get rid of the assertion. Done. |
||
Opcode = AArch64ISD::UDOT; | ||
|
||
return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B); | ||
} | ||
|
||
static SDValue performIntrinsicCombine(SDNode *N, | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure if this check is necessary (though it wasn't needed before this patch either). The code that emits this intrinsic checks this condition (see https://github.com/llvm/llvm-project/pull/92418/files#diff-da321d454a7246f8ae276bf1db2782bf26b5210b8133cb59e4d7fd45d0905decR2156-R2158), so outside of hand-written IR the case of no extends is never taken.
That, and there don't seem to be any tests that check this condition either.
That said, I'm fine with this staying as a bit of defensive coding, as I can't say for sure whether all partial reduction cases in the future will match on extends, but I figured it was worth bringing up.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This function must stand on its own and thus cannot assume anything about its input that is not part of the intrinsic's definition.