-
Notifications
You must be signed in to change notification settings - Fork 13.4k
[AArch64] Lower partial add reduction to udot or sdot #101010
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
33 commits
Select commit
Hold shift + click to select a range
561706f
[AArch64] Lower add partial reduction to udot
SamTebbs33 563d025
Remove TargetLowers and GenericLowers
SamTebbs33 e604e45
Assert that shouldExpandPartialReductionIntrinsic sees an intrinsic
SamTebbs33 9b23c96
Allow non-scalable vector types
SamTebbs33 45692df
Clean up type checking
SamTebbs33 d305452
Restrict to scalable vector types and clean up type checking
SamTebbs33 4738a20
Simplify instruction matching in shouldExpandPartialReduction
SamTebbs33 4dbf99e
Add fallback in case the nodes aren't as we expect at lowering time
SamTebbs33 c068775
Fix logic error with fallback case
SamTebbs33 636652d
Pass IntrinsicInst to shouldExpandPartialReductionIntrinsic
SamTebbs33 83015b7
Remove one-use restriction
SamTebbs33 ed6efd6
Remove new line
SamTebbs33 6364837
Remove extending/truncating for fallback case
SamTebbs33 9da416b
Clean up test target
SamTebbs33 0d23109
Remove #0 attribute from test
SamTebbs33 bc86de6
Allow i8 to i64 dot products
SamTebbs33 aa7957f
Remove isPartialReductionSupported
SamTebbs33 a58ac29
Share expansion code in SelectionDAG
SamTebbs33 5f31079
Check for NEON or SVE
SamTebbs33 2f3a0dc
Rename expansion function
SamTebbs33 00a1be2
Simplify shouldExpandPartialReductionIntrinsic
SamTebbs33 f6c5839
Remove nxv4i64 case
SamTebbs33 da20b2a
Add assertion
SamTebbs33 4697fc1
Fix subtarget check
SamTebbs33 31b7567
Emit a node instead of an intrinsic
SamTebbs33 7629679
Pass accumulator from function in tests
SamTebbs33 830df76
Remove nxv4i64 case from shouldExpandPartialReductionIntrinsic
SamTebbs33 97f3f76
Reword getPartialReduceAdd comment
SamTebbs33 bccc3e0
Remove blank lines from shouldExpandPartialReductionIntrinsic
SamTebbs33 08db4c3
Rename ExtendedType
SamTebbs33 acf964b
Remove var declaration comments
SamTebbs33 50147e0
Put DAG.getNode all on one line
SamTebbs33 6ed303f
Remove WideType
SamTebbs33 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1988,6 +1988,15 @@ bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT, | |
return false; | ||
} | ||
|
||
bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic( | ||
const IntrinsicInst *I) const { | ||
if (I->getIntrinsicID() != Intrinsic::experimental_vector_partial_reduce_add) | ||
return true; | ||
|
||
EVT VT = EVT::getEVT(I->getType()); | ||
return VT != MVT::nxv4i32 && VT != MVT::nxv2i64; | ||
} | ||
|
||
bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Remove extra line (I would have expected the formatter to remove it). There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Done. |
||
if (!Subtarget->isSVEorStreamingSVEAvailable()) | ||
return true; | ||
|
@@ -21757,6 +21766,61 @@ static SDValue tryCombineWhileLo(SDNode *N, | |
return SDValue(N, 0); | ||
} | ||
|
||
/// Try to lower an experimental_vector_partial_reduce_add intrinsic node to an
/// AArch64ISD::SDOT or AArch64ISD::UDOT node.
///
/// The lowering applies only when the value being reduced is a MUL whose two
/// operands are both sign-extends or both zero-extends of equally typed
/// sources, and the result/source types are nxv4i32/nxv16i8 or
/// nxv2i64/nxv8i16.
///
/// \param N          INTRINSIC_WO_CHAIN node for the partial reduction.
/// \param Subtarget  Queried for SVE or streaming-SVE availability.
/// \param DAG        DAG in which the replacement node is created.
/// \returns the dot-product node, or an empty SDValue when no lowering applies.
static SDValue tryLowerPartialReductionToDot(SDNode *N,
                                             const AArch64Subtarget *Subtarget,
                                             SelectionDAG &DAG) {
  assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
         getIntrinsicID(N) ==
             Intrinsic::experimental_vector_partial_reduce_add &&
         "Expected a partial reduction node");

  if (!Subtarget->isSVEorStreamingSVEAvailable())
    return SDValue();

  SDLoc DL(N);

  // Operand 1 is the narrower of the two vector operands and is used as the
  // accumulator; operand 2 is the wide value being reduced into it.
  SDValue Acc = N->getOperand(1);
  SDValue MulOp = N->getOperand(2);
  if (MulOp.getOpcode() != ISD::MUL)
    return SDValue();

  // Both multiplicands must be extended the same way (both signed or both
  // unsigned).
  SDValue ExtA = MulOp.getOperand(0);
  SDValue ExtB = MulOp.getOperand(1);
  const unsigned ExtOpcode = ExtA.getOpcode();
  if (ExtOpcode != ExtB.getOpcode() ||
      (ExtOpcode != ISD::SIGN_EXTEND && ExtOpcode != ISD::ZERO_EXTEND))
    return SDValue();

  SDValue A = ExtA.getOperand(0);
  SDValue B = ExtB.getOperand(0);
  if (A.getValueType() != B.getValueType())
    return SDValue();

  EVT ReducedType = N->getValueType(0);
  EVT MulSrcType = A.getValueType();

  // Dot products operate on chunks of four elements, so the multiply sources
  // must have four times as many elements as the result.
  const bool IsDotShape =
      (ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8) ||
      (ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16);
  if (!IsDotShape)
    return SDValue();

  // Only SIGN_EXTEND and ZERO_EXTEND survive the check above, so the opcode
  // choice is binary.
  const unsigned DotOpcode =
      ExtOpcode == ISD::SIGN_EXTEND ? AArch64ISD::SDOT : AArch64ISD::UDOT;
  return DAG.getNode(DotOpcode, DL, ReducedType, Acc, A, B);
}
|
||
static SDValue performIntrinsicCombine(SDNode *N, | ||
TargetLowering::DAGCombinerInfo &DCI, | ||
const AArch64Subtarget *Subtarget) { | ||
|
@@ -21765,6 +21829,12 @@ static SDValue performIntrinsicCombine(SDNode *N, | |
switch (IID) { | ||
default: | ||
break; | ||
case Intrinsic::experimental_vector_partial_reduce_add: { | ||
if (auto Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG)) | ||
return Dot; | ||
return DAG.getPartialReduceAdd(SDLoc(N), N->getValueType(0), | ||
N->getOperand(1), N->getOperand(2)); | ||
} | ||
case Intrinsic::aarch64_neon_vcvtfxs2fp: | ||
case Intrinsic::aarch64_neon_vcvtfxu2fp: | ||
return tryCombineFixedPointConvert(N, DCI, DAG); | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 | ||
; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s | ||
|
||
; Zero-extended i8 x i8 products reduced into an nxv4i32 accumulator; the
; assertions below expect a single udot.
define <vscale x 4 x i32> @dotp(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
; CHECK-LABEL: dotp:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: udot z0.s, z1.b, z2.b
; CHECK-NEXT: ret
entry:
  %a.ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
  %b.ext = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
  %prod = mul nuw nsw <vscale x 16 x i32> %a.ext, %b.ext
  %red = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %prod)
  ret <vscale x 4 x i32> %red
}
|
||
; Zero-extended i16 x i16 products reduced into an nxv2i64 accumulator; the
; assertions below expect a single udot on .d/.h lanes.
define <vscale x 2 x i64> @dotp_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
; CHECK-LABEL: dotp_wide:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: udot z0.d, z1.h, z2.h
; CHECK-NEXT: ret
entry:
  %a.ext = zext <vscale x 8 x i16> %a to <vscale x 8 x i64>
  %b.ext = zext <vscale x 8 x i16> %b to <vscale x 8 x i64>
  %prod = mul nuw nsw <vscale x 8 x i64> %a.ext, %b.ext
  %red = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %prod)
  ret <vscale x 2 x i64> %red
}
|
||
; Sign-extended i8 x i8 products reduced into an nxv4i32 accumulator; the
; assertions below expect a single sdot.
; The multiply carries only nsw: a product involving a negative sign-extended
; value wraps in the unsigned sense, so nuw would make it poison.
define <vscale x 4 x i32> @dotp_sext(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
; CHECK-LABEL: dotp_sext:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: sdot z0.s, z1.b, z2.b
; CHECK-NEXT: ret
entry:
  %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
  %b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i32>
  %mult = mul nsw <vscale x 16 x i32> %a.wide, %b.wide
  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
  ret <vscale x 4 x i32> %partial.reduce
}
|
||
; Sign-extended i16 x i16 products reduced into an nxv2i64 accumulator; the
; assertions below expect a single sdot on .d/.h lanes.
; The multiply carries only nsw: a product involving a negative sign-extended
; value wraps in the unsigned sense, so nuw would make it poison.
define <vscale x 2 x i64> @dotp_wide_sext(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
; CHECK-LABEL: dotp_wide_sext:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: sdot z0.d, z1.h, z2.h
; CHECK-NEXT: ret
entry:
  %a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i64>
  %b.wide = sext <vscale x 8 x i16> %b to <vscale x 8 x i64>
  %mult = mul nsw <vscale x 8 x i64> %a.wide, %b.wide
  %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %mult)
  ret <vscale x 2 x i64> %partial.reduce
}
|
||
; Negative test: the multiply sources are <vscale x 8 x i8>, only twice the
; accumulator's element count, so no dot instruction is expected. The
; assertions below expect an unpack + mla sequence instead.
; The intrinsic's type suffix matches its <vscale x 8 x i32> operand.
define <vscale x 4 x i32> @not_dotp(<vscale x 4 x i32> %acc, <vscale x 8 x i8> %a, <vscale x 8 x i8> %b) {
; CHECK-LABEL: not_dotp:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: and z1.h, z1.h, #0xff
; CHECK-NEXT: and z2.h, z2.h, #0xff
; CHECK-NEXT: ptrue p0.s
; CHECK-NEXT: uunpklo z3.s, z1.h
; CHECK-NEXT: uunpklo z4.s, z2.h
; CHECK-NEXT: uunpkhi z1.s, z1.h
; CHECK-NEXT: uunpkhi z2.s, z2.h
; CHECK-NEXT: mla z0.s, p0/m, z3.s, z4.s
; CHECK-NEXT: mla z0.s, p0/m, z1.s, z2.s
; CHECK-NEXT: ret
entry:
  %a.wide = zext <vscale x 8 x i8> %a to <vscale x 8 x i32>
  %b.wide = zext <vscale x 8 x i8> %b to <vscale x 8 x i32>
  %mult = mul nuw nsw <vscale x 8 x i32> %a.wide, %b.wide
  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv8i32(<vscale x 4 x i32> %acc, <vscale x 8 x i32> %mult)
  ret <vscale x 4 x i32> %partial.reduce
}
|
||
; Negative test: the multiply sources are <vscale x 4 x i16>, only twice the
; accumulator's element count, so no dot instruction is expected. The
; assertions below expect an unpack + mla sequence instead.
define <vscale x 2 x i64> @not_dotp_wide(<vscale x 2 x i64> %acc, <vscale x 4 x i16> %a, <vscale x 4 x i16> %b) {
; CHECK-LABEL: not_dotp_wide:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: and z1.s, z1.s, #0xffff
; CHECK-NEXT: and z2.s, z2.s, #0xffff
; CHECK-NEXT: ptrue p0.d
; CHECK-NEXT: uunpklo z3.d, z1.s
; CHECK-NEXT: uunpklo z4.d, z2.s
; CHECK-NEXT: uunpkhi z1.d, z1.s
; CHECK-NEXT: uunpkhi z2.d, z2.s
; CHECK-NEXT: mla z0.d, p0/m, z3.d, z4.d
; CHECK-NEXT: mla z0.d, p0/m, z1.d, z2.d
; CHECK-NEXT: ret
entry:
  %a.ext = zext <vscale x 4 x i16> %a to <vscale x 4 x i64>
  %b.ext = zext <vscale x 4 x i16> %b to <vscale x 4 x i64>
  %prod = mul nuw nsw <vscale x 4 x i64> %a.ext, %b.ext
  %red = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64(<vscale x 2 x i64> %acc, <vscale x 4 x i64> %prod)
  ret <vscale x 2 x i64> %red
}
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: curly braces are unnecessary here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.