
[AArch64] Lower partial add reduction to udot or svdot #101010


Merged
merged 33 commits on Sep 2, 2024
Commits
33 commits
561706f
[AArch64] Lower add partial reduction to udot
SamTebbs33 Jul 29, 2024
563d025
Remove TargetLowers and GenericLowers
SamTebbs33 Jul 29, 2024
e604e45
Assert that shouldExpandPartialReductionIntrinsic sees an intrinsic
SamTebbs33 Jul 29, 2024
9b23c96
Allow non-scalable vector types
SamTebbs33 Jul 29, 2024
45692df
Clean up type checking
SamTebbs33 Jul 29, 2024
d305452
Restrict to scalable vector types and clean up type checking
SamTebbs33 Aug 1, 2024
4738a20
Simplify instruction matching in shouldExpandPartialReduction
SamTebbs33 Aug 1, 2024
4dbf99e
Add fallback in case the nodes aren't as we expect at lowering time
SamTebbs33 Aug 9, 2024
c068775
Fix logic error with fallback case
SamTebbs33 Aug 12, 2024
636652d
Pass IntrinsicInst to shouldExpandPartialReductionIntrinsic
SamTebbs33 Aug 13, 2024
83015b7
Remove one-use restriction
SamTebbs33 Aug 13, 2024
ed6efd6
Remove new line
SamTebbs33 Aug 13, 2024
6364837
Remove extending/truncating for fallback case
SamTebbs33 Aug 13, 2024
9da416b
Clean up test target
SamTebbs33 Aug 13, 2024
0d23109
Remove #0 attribute from test
SamTebbs33 Aug 14, 2024
bc86de6
Allow i8 to i64 dot products
SamTebbs33 Aug 14, 2024
aa7957f
Remove isPartialReductionSupported
SamTebbs33 Aug 20, 2024
a58ac29
Share expansion code in SelectionDAG
SamTebbs33 Aug 21, 2024
5f31079
Check for NEON or SVE
SamTebbs33 Aug 21, 2024
2f3a0dc
Rename expansion function
SamTebbs33 Aug 28, 2024
00a1be2
Simplify shouldExpandPartialReductionIntrinsic
SamTebbs33 Aug 28, 2024
f6c5839
Remove nxv4i64 case
SamTebbs33 Aug 29, 2024
da20b2a
Add assertion
SamTebbs33 Aug 29, 2024
4697fc1
Fix subtarget check
SamTebbs33 Aug 29, 2024
31b7567
Emit a node instead of an intrinsic
SamTebbs33 Aug 29, 2024
7629679
Pass accumulator from function in tests
SamTebbs33 Aug 29, 2024
830df76
Remove nxv4i64 case from shouldExpandPartialReductionIntrinsic
SamTebbs33 Aug 30, 2024
97f3f76
Reword getPartialReduceAdd comment
SamTebbs33 Sep 2, 2024
bccc3e0
Remove blank lines from shouldExpandPartialReductionIntrinsic
SamTebbs33 Sep 2, 2024
08db4c3
Rename ExtendedType
SamTebbs33 Sep 2, 2024
acf964b
Remove var declaration comments
SamTebbs33 Sep 2, 2024
50147e0
Put DAG.getNode all on one line
SamTebbs33 Sep 2, 2024
6ed303f
Remove WideType
SamTebbs33 Sep 2, 2024
5 changes: 5 additions & 0 deletions llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1594,6 +1594,11 @@ class SelectionDAG {
/// the target's desired shift amount type.
SDValue getShiftAmountOperand(EVT LHSTy, SDValue Op);

/// Create the DAG equivalent of vector_partial_reduce where Op1 and Op2 are
/// its operands and ReducedTy is the intrinsic's return type.
SDValue getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
SDValue Op2);

/// Expand the specified \c ISD::VAARG node as the Legalize pass would.
SDValue expandVAArg(SDNode *Node);

7 changes: 7 additions & 0 deletions llvm/include/llvm/CodeGen/TargetLowering.h
@@ -453,6 +453,13 @@ class TargetLoweringBase {
return true;
}

/// Return true if the @llvm.experimental.vector.partial.reduce.* intrinsic
/// should be expanded using generic code in SelectionDAGBuilder.
virtual bool
shouldExpandPartialReductionIntrinsic(const IntrinsicInst *I) const {
return true;
}

/// Return true if the @llvm.get.active.lane.mask intrinsic should be expanded
/// using generic code in SelectionDAGBuilder.
virtual bool shouldExpandGetActiveLaneMask(EVT VT, EVT OpVT) const {
30 changes: 30 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
Expand Up @@ -74,6 +74,7 @@
#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <deque>
#include <limits>
#include <optional>
#include <set>
@@ -2426,6 +2427,35 @@ SDValue SelectionDAG::getShiftAmountOperand(EVT LHSTy, SDValue Op) {
return getZExtOrTrunc(Op, SDLoc(Op), ShTy);
}

SDValue SelectionDAG::getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
SDValue Op2) {
EVT FullTy = Op2.getValueType();

unsigned Stride = ReducedTy.getVectorMinNumElements();
unsigned ScaleFactor = FullTy.getVectorMinNumElements() / Stride;

// Collect all of the subvectors
std::deque<SDValue> Subvectors = {Op1};
for (unsigned I = 0; I < ScaleFactor; I++) {
auto SourceIndex = getVectorIdxConstant(I * Stride, DL);
Subvectors.push_back(
getNode(ISD::EXTRACT_SUBVECTOR, DL, ReducedTy, {Op2, SourceIndex}));
}

// Flatten the subvector tree
while (Subvectors.size() > 1) {
Subvectors.push_back(
getNode(ISD::ADD, DL, ReducedTy, {Subvectors[0], Subvectors[1]}));
Subvectors.pop_front();
Subvectors.pop_front();
}

assert(Subvectors.size() == 1 &&
"There should only be one subvector after tree flattening");

return Subvectors[0];
}

SDValue SelectionDAG::expandVAArg(SDNode *Node) {
SDLoc dl(Node);
const TargetLowering &TLI = getTargetLoweringInfo();
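For illustration, the following IR-level sketch (not part of the patch; all names are made up) shows the value getPartialReduceAdd builds when ReducedTy is <vscale x 4 x i32> and Op2 is <vscale x 16 x i32>: the wide operand is split into ScaleFactor = 4 subvectors of the reduced type, and the accumulator and subvectors are then summed pairwise off the front of the deque until one value remains.

; Illustrative sketch only: the IR equivalent of the nodes getPartialReduceAdd
; emits for Op1 = %acc (<vscale x 4 x i32>) and Op2 = %wide (<vscale x 16 x i32>).
define <vscale x 4 x i32> @partial_reduce_add_sketch(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %wide) {
  ; Extract ScaleFactor = 16 / 4 = 4 subvectors of the reduced type.
  %s0 = call <vscale x 4 x i32> @llvm.vector.extract.nxv4i32.nxv16i32(<vscale x 16 x i32> %wide, i64 0)
  %s1 = call <vscale x 4 x i32> @llvm.vector.extract.nxv4i32.nxv16i32(<vscale x 16 x i32> %wide, i64 4)
  %s2 = call <vscale x 4 x i32> @llvm.vector.extract.nxv4i32.nxv16i32(<vscale x 16 x i32> %wide, i64 8)
  %s3 = call <vscale x 4 x i32> @llvm.vector.extract.nxv4i32.nxv16i32(<vscale x 16 x i32> %wide, i64 12)
  ; Pairwise adds in the order the deque-based loop emits them:
  ; [acc, s0, s1, s2, s3] -> [s1, s2, s3, t0] -> [s3, t0, t1] -> [t1, t2] -> [t3]
  %t0 = add <vscale x 4 x i32> %acc, %s0
  %t1 = add <vscale x 4 x i32> %s1, %s2
  %t2 = add <vscale x 4 x i32> %s3, %t0
  %t3 = add <vscale x 4 x i32> %t1, %t2
  ret <vscale x 4 x i32> %t3
}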
31 changes: 6 additions & 25 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -8005,34 +8005,15 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
return;
}
case Intrinsic::experimental_vector_partial_reduce_add: {
SDValue OpNode = getValue(I.getOperand(1));
EVT ReducedTy = EVT::getEVT(I.getType());
EVT FullTy = OpNode.getValueType();

unsigned Stride = ReducedTy.getVectorMinNumElements();
unsigned ScaleFactor = FullTy.getVectorMinNumElements() / Stride;

// Collect all of the subvectors
std::deque<SDValue> Subvectors;
Subvectors.push_back(getValue(I.getOperand(0)));
for (unsigned i = 0; i < ScaleFactor; i++) {
auto SourceIndex = DAG.getVectorIdxConstant(i * Stride, sdl);
Subvectors.push_back(DAG.getNode(ISD::EXTRACT_SUBVECTOR, sdl, ReducedTy,
{OpNode, SourceIndex}));
}

// Flatten the subvector tree
while (Subvectors.size() > 1) {
Subvectors.push_back(DAG.getNode(ISD::ADD, sdl, ReducedTy,
{Subvectors[0], Subvectors[1]}));
Subvectors.pop_front();
Subvectors.pop_front();
if (!TLI.shouldExpandPartialReductionIntrinsic(cast<IntrinsicInst>(&I))) {
visitTargetIntrinsic(I, Intrinsic);
return;
}

assert(Subvectors.size() == 1 &&
"There should only be one subvector after tree flattening");

setValue(&I, Subvectors[0]);
setValue(&I, DAG.getPartialReduceAdd(sdl, EVT::getEVT(I.getType()),
getValue(I.getOperand(0)),
getValue(I.getOperand(1))));
return;
}
case Intrinsic::experimental_cttz_elts: {
70 changes: 70 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1988,6 +1988,15 @@ bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
return false;
}

bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
Collaborator: nit: curly braces are unnecessary here.
SamTebbs33 (Author): Done.
const IntrinsicInst *I) const {
if (I->getIntrinsicID() != Intrinsic::experimental_vector_partial_reduce_add)
return true;

EVT VT = EVT::getEVT(I->getType());
return VT != MVT::nxv4i32 && VT != MVT::nxv2i64;
}
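
As a hypothetical illustration of this hook (the functions below are not in the PR's test file): a partial reduction whose result type is nxv4i32 or nxv2i64 makes the hook return false, so the intrinsic survives to instruction selection and the AArch64 combine later in this file can try to form a dot product, while any other result type, for example nxv4i64, makes it return true and the call is expanded generically in SelectionDAGBuilder.

; Hypothetical examples, illustrative names only.

; nxv4i32 result: the hook returns false, so the intrinsic is kept for the
; target and may later be lowered to udot/sdot by the DAG combine.
define <vscale x 4 x i32> @kept_for_target(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %wide) {
  %r = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %wide)
  ret <vscale x 4 x i32> %r
}

; nxv4i64 result: not one of the two supported types, so the hook returns true
; and SelectionDAGBuilder expands the call with getPartialReduceAdd.
define <vscale x 4 x i64> @expanded_generically(<vscale x 4 x i64> %acc, <vscale x 8 x i64> %wide) {
  %r = call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv8i64(<vscale x 4 x i64> %acc, <vscale x 8 x i64> %wide)
  ret <vscale x 4 x i64> %r
}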

bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
Collaborator: Remove extra line (I would have expected the formatter to remove it).
SamTebbs33 (Author): Done.
if (!Subtarget->isSVEorStreamingSVEAvailable())
return true;
@@ -21757,6 +21766,61 @@ static SDValue tryCombineWhileLo(SDNode *N,
return SDValue(N, 0);
}

SDValue tryLowerPartialReductionToDot(SDNode *N,
const AArch64Subtarget *Subtarget,
SelectionDAG &DAG) {

assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
getIntrinsicID(N) ==
Intrinsic::experimental_vector_partial_reduce_add &&
"Expected a partial reduction node");

if (!Subtarget->isSVEorStreamingSVEAvailable())
return SDValue();

SDLoc DL(N);

// The narrower of the two operands. Used as the accumulator
auto NarrowOp = N->getOperand(1);
auto MulOp = N->getOperand(2);
if (MulOp->getOpcode() != ISD::MUL)
return SDValue();

auto ExtA = MulOp->getOperand(0);
auto ExtB = MulOp->getOperand(1);
bool IsSExt = ExtA->getOpcode() == ISD::SIGN_EXTEND;
bool IsZExt = ExtA->getOpcode() == ISD::ZERO_EXTEND;
if (ExtA->getOpcode() != ExtB->getOpcode() || (!IsSExt && !IsZExt))
return SDValue();

auto A = ExtA->getOperand(0);
auto B = ExtB->getOperand(0);
if (A.getValueType() != B.getValueType())
return SDValue();

unsigned Opcode = 0;

if (IsSExt)
Opcode = AArch64ISD::SDOT;
else if (IsZExt)
Opcode = AArch64ISD::UDOT;

assert(Opcode != 0 && "Unexpected dot product case encountered.");

EVT ReducedType = N->getValueType(0);
EVT MulSrcType = A.getValueType();

// Dot products operate on chunks of four elements so there must be four times
// as many elements in the wide type
if (ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8)
return DAG.getNode(Opcode, DL, MVT::nxv4i32, NarrowOp, A, B);

if (ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16)
return DAG.getNode(Opcode, DL, MVT::nxv2i64, NarrowOp, A, B);

return SDValue();
}
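
One shape the matcher deliberately rejects, shown here as a hypothetical input (it is not in the PR's tests): when the two multiplicands are extended differently, ExtA and ExtB have different opcodes, the function returns an empty SDValue, and the combine below falls back to getPartialReduceAdd rather than emitting a dot product.

; Hypothetical example, not part of the PR's test file: mixed sign/zero
; extension does not match the dot-product pattern, so with this patch the
; partial reduction is expanded into extracts and adds instead.
define <vscale x 4 x i32> @not_dotp_mixed_extend(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
entry:
  %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
  %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
  %mult = mul <vscale x 16 x i32> %a.wide, %b.wide
  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
  ret <vscale x 4 x i32> %partial.reduce
}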

static SDValue performIntrinsicCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
const AArch64Subtarget *Subtarget) {
@@ -21765,6 +21829,12 @@ static SDValue performIntrinsicCombine(SDNode *N,
switch (IID) {
default:
break;
case Intrinsic::experimental_vector_partial_reduce_add: {
if (auto Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG))
return Dot;
return DAG.getPartialReduceAdd(SDLoc(N), N->getValueType(0),
N->getOperand(1), N->getOperand(2));
}
case Intrinsic::aarch64_neon_vcvtfxs2fp:
case Intrinsic::aarch64_neon_vcvtfxu2fp:
return tryCombineFixedPointConvert(N, DCI, DAG);
3 changes: 3 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -998,6 +998,9 @@ class AArch64TargetLowering : public TargetLowering {

bool shouldExpandGetActiveLaneMask(EVT VT, EVT OpVT) const override;

bool
shouldExpandPartialReductionIntrinsic(const IntrinsicInst *I) const override;

bool shouldExpandCttzElements(EVT VT) const override;

/// If a change in streaming mode is required on entry to/return from a
Expand Down
96 changes: 96 additions & 0 deletions llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
@@ -0,0 +1,96 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s

define <vscale x 4 x i32> @dotp(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
; CHECK-LABEL: dotp:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: udot z0.s, z1.b, z2.b
; CHECK-NEXT: ret
entry:
%a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
%b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
%mult = mul nuw nsw <vscale x 16 x i32> %a.wide, %b.wide
%partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
ret <vscale x 4 x i32> %partial.reduce
}

define <vscale x 2 x i64> @dotp_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
; CHECK-LABEL: dotp_wide:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: udot z0.d, z1.h, z2.h
; CHECK-NEXT: ret
entry:
%a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i64>
%b.wide = zext <vscale x 8 x i16> %b to <vscale x 8 x i64>
%mult = mul nuw nsw <vscale x 8 x i64> %a.wide, %b.wide
%partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %mult)
ret <vscale x 2 x i64> %partial.reduce
}

define <vscale x 4 x i32> @dotp_sext(<vscale x 4 x i32> %accc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
; CHECK-LABEL: dotp_sext:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: sdot z0.s, z1.b, z2.b
; CHECK-NEXT: ret
entry:
%a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
%b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i32>
%mult = mul nuw nsw <vscale x 16 x i32> %a.wide, %b.wide
%partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %accc, <vscale x 16 x i32> %mult)
ret <vscale x 4 x i32> %partial.reduce
}

define <vscale x 2 x i64> @dotp_wide_sext(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
; CHECK-LABEL: dotp_wide_sext:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: sdot z0.d, z1.h, z2.h
; CHECK-NEXT: ret
entry:
%a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i64>
%b.wide = sext <vscale x 8 x i16> %b to <vscale x 8 x i64>
%mult = mul nuw nsw <vscale x 8 x i64> %a.wide, %b.wide
%partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %mult)
ret <vscale x 2 x i64> %partial.reduce
}

define <vscale x 4 x i32> @not_dotp(<vscale x 4 x i32> %acc, <vscale x 8 x i8> %a, <vscale x 8 x i8> %b) {
; CHECK-LABEL: not_dotp:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: and z1.h, z1.h, #0xff
; CHECK-NEXT: and z2.h, z2.h, #0xff
; CHECK-NEXT: ptrue p0.s
; CHECK-NEXT: uunpklo z3.s, z1.h
; CHECK-NEXT: uunpklo z4.s, z2.h
; CHECK-NEXT: uunpkhi z1.s, z1.h
; CHECK-NEXT: uunpkhi z2.s, z2.h
; CHECK-NEXT: mla z0.s, p0/m, z3.s, z4.s
; CHECK-NEXT: mla z0.s, p0/m, z1.s, z2.s
; CHECK-NEXT: ret
entry:
%a.wide = zext <vscale x 8 x i8> %a to <vscale x 8 x i32>
%b.wide = zext <vscale x 8 x i8> %b to <vscale x 8 x i32>
%mult = mul nuw nsw <vscale x 8 x i32> %a.wide, %b.wide
%partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv8i32(<vscale x 4 x i32> %acc, <vscale x 8 x i32> %mult)
ret <vscale x 4 x i32> %partial.reduce
}

define <vscale x 2 x i64> @not_dotp_wide(<vscale x 2 x i64> %acc, <vscale x 4 x i16> %a, <vscale x 4 x i16> %b) {
; CHECK-LABEL: not_dotp_wide:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: and z1.s, z1.s, #0xffff
; CHECK-NEXT: and z2.s, z2.s, #0xffff
; CHECK-NEXT: ptrue p0.d
; CHECK-NEXT: uunpklo z3.d, z1.s
; CHECK-NEXT: uunpklo z4.d, z2.s
; CHECK-NEXT: uunpkhi z1.d, z1.s
; CHECK-NEXT: uunpkhi z2.d, z2.s
; CHECK-NEXT: mla z0.d, p0/m, z3.d, z4.d
; CHECK-NEXT: mla z0.d, p0/m, z1.d, z2.d
; CHECK-NEXT: ret
entry:
%a.wide = zext <vscale x 4 x i16> %a to <vscale x 4 x i64>
%b.wide = zext <vscale x 4 x i16> %b to <vscale x 4 x i64>
%mult = mul nuw nsw <vscale x 4 x i64> %a.wide, %b.wide
%partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64(<vscale x 2 x i64> %acc, <vscale x 4 x i64> %mult)
ret <vscale x 2 x i64> %partial.reduce
}