[LoopVectorizer] Add support for chaining partial reductions #120272
Conversation
Force-pushed d5d2c0d to 34d5f25
✅ With the latest revision this PR passed the C/C++ code formatter.
As a side effect of this work being motivated by supporting complex dot-product instructions (and thus allowing the use of straight loads instead of gathers/deinterleaving loads), this PR also introduces support for partial reduction subtractions by negating the second operand of the `llvm.experimental.vector.partial.reduce.add` intrinsic.
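In scalar terms the rewrite relies on the identity `acc - x == acc + (0 - x)`; a minimal sketch of the idea (hypothetical function name, not the actual recipe code):

```cpp
#include <cstdint>
#include <vector>

// Hypothetical illustration of lowering a partial-reduction subtract as an
// add of the negated input: acc -= x becomes acc += (0 - x), so the existing
// partial-reduction "add" machinery can be reused unchanged.
int32_t partialReduceSub(int32_t Acc, const std::vector<int32_t> &Inputs) {
  for (int32_t X : Inputs) {
    int32_t NegX = 0 - X; // negate the second operand...
    Acc += NegX;          // ...then feed it to the plain add reduction.
  }
  return Acc;
}
```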
@llvm/pr-subscribers-vectorizers @llvm/pr-subscribers-backend-aarch64 @llvm/pr-subscribers-llvm-transforms

Author: Nicholas Guy (NickGuy-Arm)

Changes

Chaining partial reductions, where multiple partial reductions share an accumulator, allows more values to be combined as part of the reduction without discarding the semantics of the partial reduction itself.

Patch is 45.43 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/120272.diff

6 Files Affected:
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 8e7e590c173ff2..c6cebcca679353 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -368,7 +368,7 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
InstructionCost Invalid = InstructionCost::getInvalid();
InstructionCost Cost(TTI::TCC_Basic);
- if (Opcode != Instruction::Add)
+ if (Opcode != Instruction::Add && Opcode != Instruction::Sub)
return Invalid;
if (InputTypeA != InputTypeB)
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 99f6a8860f0f4d..a18cf3c9bec2bd 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -8790,12 +8790,12 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(Instruction *I,
/// are valid so recipes can be formed later.
void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
// Find all possible partial reductions.
- SmallVector<std::pair<PartialReductionChain, unsigned>, 1>
+ SmallVector<std::pair<PartialReductionChain, unsigned>>
PartialReductionChains;
- for (const auto &[Phi, RdxDesc] : Legal->getReductionVars())
- if (std::optional<std::pair<PartialReductionChain, unsigned>> Pair =
- getScaledReduction(Phi, RdxDesc, Range))
- PartialReductionChains.push_back(*Pair);
+ for (const auto &[Phi, RdxDesc] : Legal->getReductionVars()) {
+ if (auto SR = getScaledReduction(Phi, RdxDesc.getLoopExitInstr(), Range))
+ PartialReductionChains.append(*SR);
+ }
// A partial reduction is invalid if any of its extends are used by
// something that isn't another partial reduction. This is because the
@@ -8823,26 +8823,41 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
}
}
-std::optional<std::pair<PartialReductionChain, unsigned>>
-VPRecipeBuilder::getScaledReduction(PHINode *PHI,
- const RecurrenceDescriptor &Rdx,
+std::optional<SmallVector<std::pair<PartialReductionChain, unsigned>>>
+VPRecipeBuilder::getScaledReduction(Instruction *PHI, Instruction *RdxExitInstr,
VFRange &Range) {
+
+ if (!CM.TheLoop->contains(RdxExitInstr))
+ return std::nullopt;
+
// TODO: Allow scaling reductions when predicating. The select at
// the end of the loop chooses between the phi value and most recent
// reduction result, both of which have different VFs to the active lane
// mask when scaling.
- if (CM.blockNeedsPredicationForAnyReason(Rdx.getLoopExitInstr()->getParent()))
+ if (CM.blockNeedsPredicationForAnyReason(RdxExitInstr->getParent()))
return std::nullopt;
- auto *Update = dyn_cast<BinaryOperator>(Rdx.getLoopExitInstr());
+ auto *Update = dyn_cast<BinaryOperator>(RdxExitInstr);
if (!Update)
return std::nullopt;
Value *Op = Update->getOperand(0);
Value *PhiOp = Update->getOperand(1);
- if (Op == PHI) {
- Op = Update->getOperand(1);
- PhiOp = Update->getOperand(0);
+ if (Op == PHI)
+ std::swap(Op, PhiOp);
+
+ SmallVector<std::pair<PartialReductionChain, unsigned>> Chains;
+
+ if (auto *OpInst = dyn_cast<Instruction>(Op)) {
+ if (auto SR0 = getScaledReduction(PHI, OpInst, Range)) {
+ Chains.append(*SR0);
+ PHI = SR0->rbegin()->first.Reduction;
+
+ Op = Update->getOperand(0);
+ PhiOp = Update->getOperand(1);
+ if (Op == PHI)
+ std::swap(Op, PhiOp);
+ }
}
if (PhiOp != PHI)
return std::nullopt;
@@ -8860,12 +8875,16 @@ VPRecipeBuilder::getScaledReduction(PHINode *PHI,
Instruction *ExtA = cast<Instruction>(BinOp->getOperand(0));
Instruction *ExtB = cast<Instruction>(BinOp->getOperand(1));
+ // Check that the extends extend from the same type.
+ if (A->getType() != B->getType())
+ return std::nullopt;
+
TTI::PartialReductionExtendKind OpAExtend =
TargetTransformInfo::getPartialReductionExtendKind(ExtA);
TTI::PartialReductionExtendKind OpBExtend =
TargetTransformInfo::getPartialReductionExtendKind(ExtB);
- PartialReductionChain Chain(Rdx.getLoopExitInstr(), ExtA, ExtB, BinOp);
+ PartialReductionChain Chain(RdxExitInstr, ExtA, ExtB, BinOp);
unsigned TargetScaleFactor =
PHI->getType()->getPrimitiveSizeInBits().getKnownScalarFactor(
@@ -8880,9 +8899,9 @@ VPRecipeBuilder::getScaledReduction(PHINode *PHI,
return Cost.isValid();
},
Range))
- return std::make_pair(Chain, TargetScaleFactor);
+ Chains.push_back(std::make_pair(Chain, TargetScaleFactor));
- return std::nullopt;
+ return Chains;
}
VPRecipeBase *
@@ -8979,7 +8998,9 @@ VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction,
VPValue *BinOp = Operands[0];
VPValue *Phi = Operands[1];
- if (isa<VPReductionPHIRecipe>(BinOp->getDefiningRecipe()))
+ VPRecipeBase *BinOpRecipe = BinOp->getDefiningRecipe();
+ if (isa<VPReductionPHIRecipe>(BinOpRecipe) ||
+ isa<VPPartialReductionRecipe>(BinOpRecipe))
std::swap(BinOp, Phi);
return new VPPartialReductionRecipe(Reduction->getOpcode(), BinOp, Phi,
diff --git a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
index cf653e2d3e6584..6be9a716cacbfd 100644
--- a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
+++ b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
@@ -144,8 +144,8 @@ class VPRecipeBuilder {
/// Returns null if no scaled reduction was found, otherwise a pair with a
/// struct containing reduction information and the scaling factor between the
/// number of elements in the input and output.
- std::optional<std::pair<PartialReductionChain, unsigned>>
- getScaledReduction(PHINode *PHI, const RecurrenceDescriptor &Rdx,
+ std::optional<SmallVector<std::pair<PartialReductionChain, unsigned>>>
+ getScaledReduction(Instruction *PHI, Instruction *RdxExitInstr,
VFRange &Range);
public:
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 87f87bf1437196..9e09bfc62105c1 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2453,13 +2453,16 @@ class VPPartialReductionRecipe : public VPSingleDefRecipe {
: VPSingleDefRecipe(VPDef::VPPartialReductionSC,
ArrayRef<VPValue *>({Op0, Op1}), ReductionInst),
Opcode(Opcode) {
- assert(isa<VPReductionPHIRecipe>(getOperand(1)->getDefiningRecipe()) &&
+ auto *DefiningRecipe = getOperand(1)->getDefiningRecipe();
+ assert((isa<VPReductionPHIRecipe>(DefiningRecipe) ||
+ isa<VPPartialReductionRecipe>(DefiningRecipe)) &&
"Unexpected operand order for partial reduction recipe");
}
~VPPartialReductionRecipe() override = default;
VPPartialReductionRecipe *clone() override {
- return new VPPartialReductionRecipe(Opcode, getOperand(0), getOperand(1));
+ return new VPPartialReductionRecipe(Opcode, getOperand(0), getOperand(1),
+ getUnderlyingInstr());
}
VP_CLASSOF_IMPL(VPDef::VPPartialReductionSC)
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 979a8e0768a991..668c317033fe78 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -317,13 +317,22 @@ void VPPartialReductionRecipe::execute(VPTransformState &State) {
State.setDebugLocFrom(getDebugLoc());
auto &Builder = State.Builder;
- assert(getOpcode() == Instruction::Add &&
- "Unhandled partial reduction opcode");
-
Value *BinOpVal = State.get(getOperand(0));
Value *PhiVal = State.get(getOperand(1));
assert(PhiVal && BinOpVal && "Phi and Mul must be set");
+ auto Opcode = getOpcode();
+
+ // Currently we don't have a partial_reduce_sub intrinsic,
+ // so mimic the behaviour by negating the second operand
+ if (Opcode == Instruction::Sub) {
+ BinOpVal = Builder.CreateSub(Constant::getNullValue(BinOpVal->getType()),
+ BinOpVal);
+ Opcode = Instruction::Add;
+ }
+
+ assert(Opcode == Instruction::Add && "Unhandled partial reduction opcode");
+
Type *RetTy = PhiVal->getType();
CallInst *V = Builder.CreateIntrinsic(
diff --git a/llvm/test/CodeGen/AArch64/partial-reduction-chained.ll b/llvm/test/CodeGen/AArch64/partial-reduction-chained.ll
new file mode 100644
index 00000000000000..4272e2f7552495
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/partial-reduction-chained.ll
@@ -0,0 +1,568 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt -S -o - %s --passes=loop-vectorize -force-vector-width=16 -force-vector-interleave=1 -scalable-vectorization=on | FileCheck %s
+
+target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
+target triple = "aarch64-none-unknown-elf"
+
+; Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(argmem: read) vscale_range(1,16)
+define i32 @chained_partial_reduce_add_sub(ptr %a, ptr %b, ptr %c, i32 %N) #0 {
+; CHECK-LABEL: define i32 @chained_partial_reduce_add_sub(
+; CHECK-SAME: ptr [[A:%.*]], ptr [[B:%.*]], ptr [[C:%.*]], i32 [[N:%.*]]) #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[CMP28_NOT:%.*]] = icmp ult i32 [[N]], 2
+; CHECK-NEXT: [[DIV27:%.*]] = lshr i32 [[N]], 1
+; CHECK-NEXT: [[WIDE_TRIP_COUNT:%.*]] = zext nneg i32 [[DIV27]] to i64
+; CHECK-NEXT: [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT: [[TMP1:%.*]] = mul i64 [[TMP0]], 16
+; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[WIDE_TRIP_COUNT]], [[TMP1]]
+; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
+; CHECK: vector.ph:
+; CHECK-NEXT: [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT: [[TMP3:%.*]] = mul i64 [[TMP2]], 16
+; CHECK-NEXT: [[N_MOD_VF:%.*]] = urem i64 [[WIDE_TRIP_COUNT]], [[TMP3]]
+; CHECK-NEXT: [[N_VEC:%.*]] = sub i64 [[WIDE_TRIP_COUNT]], [[N_MOD_VF]]
+; CHECK-NEXT: [[TMP4:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT: [[TMP5:%.*]] = mul i64 [[TMP4]], 16
+; CHECK-NEXT: br label [[VECTOR_BODY:%.*]]
+; CHECK: vector.body:
+; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT: [[VEC_PHI1:%.*]] = phi <vscale x 4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE9:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT: [[TMP6:%.*]] = add i64 [[INDEX]], 0
+; CHECK-NEXT: [[TMP7:%.*]] = getelementptr inbounds nuw i8, ptr [[A]], i64 [[TMP6]]
+; CHECK-NEXT: [[TMP8:%.*]] = getelementptr inbounds nuw i8, ptr [[B]], i64 [[TMP6]]
+; CHECK-NEXT: [[TMP9:%.*]] = getelementptr inbounds nuw i8, ptr [[C]], i64 [[TMP6]]
+; CHECK-NEXT: [[TMP11:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP7]], i32 0
+; CHECK-NEXT: [[WIDE_LOAD4:%.*]] = load <vscale x 16 x i8>, ptr [[TMP11]], align 1
+; CHECK-NEXT: [[TMP13:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP8]], i32 0
+; CHECK-NEXT: [[WIDE_LOAD3:%.*]] = load <vscale x 16 x i8>, ptr [[TMP13]], align 1
+; CHECK-NEXT: [[TMP14:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP9]], i32 0
+; CHECK-NEXT: [[WIDE_LOAD6:%.*]] = load <vscale x 16 x i8>, ptr [[TMP14]], align 1
+; CHECK-NEXT: [[TMP25:%.*]] = sext <vscale x 16 x i8> [[WIDE_LOAD4]] to <vscale x 16 x i32>
+; CHECK-NEXT: [[TMP18:%.*]] = sext <vscale x 16 x i8> [[WIDE_LOAD3]] to <vscale x 16 x i32>
+; CHECK-NEXT: [[TMP27:%.*]] = sext <vscale x 16 x i8> [[WIDE_LOAD6]] to <vscale x 16 x i32>
+; CHECK-NEXT: [[TMP29:%.*]] = mul nsw <vscale x 16 x i32> [[TMP25]], [[TMP18]]
+; CHECK-NEXT: [[PARTIAL_REDUCE7:%.*]] = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> [[VEC_PHI1]], <vscale x 16 x i32> [[TMP29]])
+; CHECK-NEXT: [[TMP31:%.*]] = mul nsw <vscale x 16 x i32> [[TMP25]], [[TMP27]]
+; CHECK-NEXT: [[TMP33:%.*]] = sub <vscale x 16 x i32> zeroinitializer, [[TMP31]]
+; CHECK-NEXT: [[PARTIAL_REDUCE9]] = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> [[PARTIAL_REDUCE7]], <vscale x 16 x i32> [[TMP33]])
+; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP5]]
+; CHECK-NEXT: [[TMP22:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT: br i1 [[TMP22]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
+; CHECK: middle.block:
+; CHECK-NEXT: [[TMP23:%.*]] = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[PARTIAL_REDUCE9]])
+; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[WIDE_TRIP_COUNT]], [[N_VEC]]
+; CHECK-NEXT: br i1 [[CMP_N]], label [[FOR_COND_CLEANUP:%.*]], label [[SCALAR_PH]]
+; CHECK: scalar.ph:
+; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi i64 [ [[N_VEC]], [[MIDDLE_BLOCK]] ], [ 0, [[ENTRY:%.*]] ]
+; CHECK-NEXT: [[BC_MERGE_RDX:%.*]] = phi i32 [ [[TMP23]], [[MIDDLE_BLOCK]] ], [ 0, [[ENTRY]] ]
+; CHECK-NEXT: br label [[FOR_BODY:%.*]]
+; CHECK: for.cond.cleanup:
+; CHECK-NEXT: [[RES_0_LCSSA:%.*]] = phi i32 [ [[SUB:%.*]], [[FOR_BODY]] ], [ [[TMP23]], [[MIDDLE_BLOCK]] ]
+; CHECK-NEXT: ret i32 [[RES_0_LCSSA]]
+; CHECK: for.body:
+; CHECK-NEXT: [[INDVARS_IV:%.*]] = phi i64 [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ], [ [[INDVARS_IV_NEXT:%.*]], [[FOR_BODY]] ]
+; CHECK-NEXT: [[RES:%.*]] = phi i32 [ [[BC_MERGE_RDX]], [[SCALAR_PH]] ], [ [[SUB]], [[FOR_BODY]] ]
+; CHECK-NEXT: [[A_PTR:%.*]] = getelementptr inbounds nuw i8, ptr [[A]], i64 [[INDVARS_IV]]
+; CHECK-NEXT: [[B_PTR:%.*]] = getelementptr inbounds nuw i8, ptr [[B]], i64 [[INDVARS_IV]]
+; CHECK-NEXT: [[C_PTR:%.*]] = getelementptr inbounds nuw i8, ptr [[C]], i64 [[INDVARS_IV]]
+; CHECK-NEXT: [[B_VAL:%.*]] = load i8, ptr [[A_PTR]], align 1
+; CHECK-NEXT: [[C_VAL:%.*]] = load i8, ptr [[B_PTR]], align 1
+; CHECK-NEXT: [[D_VAL:%.*]] = load i8, ptr [[C_PTR]], align 1
+; CHECK-NEXT: [[B_EXT:%.*]] = sext i8 [[B_VAL]] to i32
+; CHECK-NEXT: [[C_EXT:%.*]] = sext i8 [[C_VAL]] to i32
+; CHECK-NEXT: [[D_EXT:%.*]] = sext i8 [[D_VAL]] to i32
+; CHECK-NEXT: [[MUL_AC:%.*]] = mul nsw i32 [[B_EXT]], [[C_EXT]]
+; CHECK-NEXT: [[ADD:%.*]] = add nsw i32 [[RES]], [[MUL_AC]]
+; CHECK-NEXT: [[MUL_DB:%.*]] = mul nsw i32 [[B_EXT]], [[D_EXT]]
+; CHECK-NEXT: [[SUB]] = sub i32 [[ADD]], [[MUL_DB]]
+; CHECK-NEXT: [[INDVARS_IV_NEXT]] = add nuw nsw i64 [[INDVARS_IV]], 1
+; CHECK-NEXT: [[EXITCOND_NOT:%.*]] = icmp eq i64 [[INDVARS_IV_NEXT]], [[WIDE_TRIP_COUNT]]
+; CHECK-NEXT: br i1 [[EXITCOND_NOT]], label [[FOR_COND_CLEANUP]], label [[FOR_BODY]], !llvm.loop [[LOOP3:![0-9]+]], !loop [[META4:![0-9]+]]
+;
+entry:
+ %cmp28.not = icmp ult i32 %N, 2
+ %div27 = lshr i32 %N, 1
+ %wide.trip.count = zext nneg i32 %div27 to i64
+ br label %for.body
+
+for.cond.cleanup: ; preds = %for.cond.cleanup.loopexit, %entry
+ %res.0.lcssa = phi i32 [ %sub, %for.body ]
+ ret i32 %res.0.lcssa
+
+for.body: ; preds = %for.body.preheader, %for.body
+ %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
+ %res = phi i32 [ 0, %entry ], [ %sub, %for.body ]
+ %a.ptr = getelementptr inbounds nuw i8, ptr %a, i64 %indvars.iv
+ %b.ptr = getelementptr inbounds nuw i8, ptr %b, i64 %indvars.iv
+ %c.ptr = getelementptr inbounds nuw i8, ptr %c, i64 %indvars.iv
+ %a.val = load i8, ptr %a.ptr, align 1
+ %b.val = load i8, ptr %b.ptr, align 1
+ %c.val = load i8, ptr %c.ptr, align 1
+ %a.ext = sext i8 %a.val to i32
+ %b.ext = sext i8 %b.val to i32
+ %c.ext = sext i8 %c.val to i32
+ %mul.ab = mul nsw i32 %a.ext, %b.ext
+ %add = add nsw i32 %res, %mul.ab
+ %mul.ac = mul nsw i32 %a.ext, %c.ext
+ %sub = sub i32 %add, %mul.ac
+ %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+ %exitcond.not = icmp eq i64 %indvars.iv.next, %wide.trip.count
+ br i1 %exitcond.not, label %for.cond.cleanup, label %for.body, !loop !1
+}
+
+define i32 @chained_partial_reduce_add_add(ptr %a, ptr %b, ptr %c, i32 %N) #0 {
+; CHECK-LABEL: define i32 @chained_partial_reduce_add_add(
+; CHECK-SAME: ptr [[A:%.*]], ptr [[B:%.*]], ptr [[C:%.*]], i32 [[N:%.*]]) #[[ATTR0]] {
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[CMP28_NOT:%.*]] = icmp ult i32 [[N]], 2
+; CHECK-NEXT: [[DIV27:%.*]] = lshr i32 [[N]], 1
+; CHECK-NEXT: [[WIDE_TRIP_COUNT:%.*]] = zext nneg i32 [[DIV27]] to i64
+; CHECK-NEXT: [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT: [[TMP1:%.*]] = mul i64 [[TMP0]], 16
+; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[WIDE_TRIP_COUNT]], [[TMP1]]
+; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
+; CHECK: vector.ph:
+; CHECK-NEXT: [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT: [[TMP3:%.*]] = mul i64 [[TMP2]], 16
+; CHECK-NEXT: [[N_MOD_VF:%.*]] = urem i64 [[WIDE_TRIP_COUNT]], [[TMP3]]
+; CHECK-NEXT: [[N_VEC:%.*]] = sub i64 [[WIDE_TRIP_COUNT]], [[N_MOD_VF]]
+; CHECK-NEXT: [[TMP4:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT: [[TMP5:%.*]] = mul i64 [[TMP4]], 16
+; CHECK-NEXT: br label [[VECTOR_BODY:%.*]]
+; CHECK: vector.body:
+; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT: [[VEC_PHI1:%.*]] = phi <vscale x 4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE9:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT: [[TMP6:%.*]] = add i64 [[INDEX]], 0
+; CHECK-NEXT: [[TMP7:%.*]] = getelementptr inbounds nuw i8, ptr [[A]], i64 [[TMP6]]
+; CHECK-NEXT: [[TMP8:%.*]] = getelementptr inbounds nuw i8, ptr [[B]], i64 [[TMP6]]
+; CHECK-NEXT: [[TMP9:%.*]] = getelementptr inbounds nuw i8, ptr [[C]], i64 [[TMP6]]
+; CHECK-NEXT: [[TMP11:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP7]], i32 0
+; CHECK-NEXT: [[WIDE_LOAD4:%.*]] = load <vscale x 16 x i8>, ptr [[TMP11]], align 1
+; CHECK-NEXT: [[TMP13:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP8]], i32 0
+; CHECK-NEXT: [[WIDE_LOAD3:%.*]] = load <vscale x 16 x i8>, ptr [[TMP13]], align 1
+; CHECK-NEXT: [[TMP14:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP9]], i32 0
+; CHECK-NEXT: [[WIDE_LOAD6:%.*]] = load <vscale x 16 x i8>, ptr [[TMP14]], align 1
+; CHECK-NEXT: [[TMP25:%.*]] = sext <vscale x 16 x i8> [[WIDE_LOAD4]] to <vscale x 16 x i32>
+; CHECK-NEXT: [[TMP18:%.*]] = sext <vscale x 16 x i8> [[WIDE_LOAD3]] to <vscale x 16 x i32>
+; CHECK-NEXT: [[TMP27:%.*]] = sext <vscale x 16 x i8> [[WIDE_LOAD6]] to <vscale x 16 x i32>
+; CHECK-NEXT: [[TMP29:%.*]] = mul nsw <vscale x 16 x i32> [[TMP25]], [[TMP18]]
+; CHECK-NEXT: [[PARTIAL_REDUCE7:%.*]] = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> [[VEC_PHI1]], <vscale x 16 x i32> [[TMP29]])
+; CHECK-NEXT: [[TMP31:%.*]] = mul nsw <vscale x 16 x i32> [[TMP25]], [[TMP27]]
+; CHECK-NEXT: [[PARTIAL_REDUCE9]] = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> [[PARTIAL_REDUCE7]], <vscale x 16 x i32> [[TMP31]])
+; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP5]]
+; CHECK-NEXT: [[TMP21:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT: br i1 [[TMP21]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP6:![0-9]+]]
+; CHECK: middle.block:
+; CHECK-NEXT: [[TMP22:%.*]] = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[PARTIAL_REDUCE9]])
+; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[WIDE_TRIP_COUNT]], [[N_VEC]]
+; CHECK-NEXT: br i1 [[CMP_N]], label [[FOR_COND_CLEANUP:%.*]], label [[SCALAR_PH]]
+; CHECK: scalar.ph:
+; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi i64 [ [[N_VEC]], [[MIDDLE_BLOCK]] ], [ 0, [[ENTRY:%.*]] ]
+; CHECK-NEXT: [[BC_MERGE_RDX:%.*]] = phi i32 [ [[TMP22]], [[MIDDLE_BLOCK]] ], [ 0, [[ENTRY]] ]
+; CHECK-NEXT: br label [[FOR_BODY:%.*]]
+; CHECK: for.cond.cleanup:
+; CHECK-NEXT: [[RES_0_LCSSA:%.*]] = phi i32 [ [[SUB:%.*]], [[FOR_BODY]] ], [ [[TMP22]], [[MIDDLE_BLOCK]] ]
+; CHECK-NEXT: ret i32 [[RES_0_LCSSA]]
+; C...
[truncated]
Sounds like supporting partial reduction subtractions would be a good extension on its own (and testable independently); could it be split off?
Yep. I've pulled it out completely, and will create a new PR for adding sub support.

Edit: PR for adding subtraction partial reductions: #123636
Review thread on llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-chained.ll (outdated; resolved)
LGTM with a couple of ideas.
@@ -8979,7 +8997,9 @@ VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction,
  VPValue *BinOp = Operands[0];
  VPValue *Phi = Operands[1];
This should probably be called Accumulator or something like that, since that's now more accurate.
@@ -2453,7 +2453,9 @@ class VPPartialReductionRecipe : public VPSingleDefRecipe {
       : VPSingleDefRecipe(VPDef::VPPartialReductionSC,
                           ArrayRef<VPValue *>({Op0, Op1}), ReductionInst),
         Opcode(Opcode) {
-    assert(isa<VPReductionPHIRecipe>(getOperand(1)->getDefiningRecipe()) &&
+    auto *DefiningRecipe = getOperand(1)->getDefiningRecipe();
Perhaps AccumulatorRecipe would be more descriptive.
Ok great, thanks! It's a draft at the moment; any issues?
Only when isolating/creating the tests. As far as I can tell, subtraction partial reductions will rarely, if ever, be provided as an input to VPlan (due to no …
CC @thurstond https://lab.llvm.org/buildbot/#/builders/24/builds/4540/steps/12/logs/stdio
@NickGuy-Arm Reverted #124198
…120272" (#124282)

Change `getScaledReduction` to take an existing vector, rather than creating and returning a new one on each call. Rename `getScaledReduction` to `getScaledReductions` to more accurately reflect what it now does.

Co-authored-by: Karlo Basioli <[email protected]>
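A toy sketch of that out-parameter refactor (hypothetical types and names, mirroring the commit description above rather than the actual LLVM signature):

```cpp
#include <utility>
#include <vector>

// The caller owns the vector; each call appends any chains it finds and
// reports whether anything was added, instead of building a fresh vector.
struct PartialReductionChain { /* reduction bookkeeping elided */ };

bool getScaledReductions(
    std::vector<std::pair<PartialReductionChain, unsigned>> &Chains) {
  const size_t SizeBefore = Chains.size();
  // ... discovery logic would push_back into Chains here ...
  return Chains.size() != SizeBefore;
}
```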
Chaining partial reductions, where multiple partial reductions share an accumulator, allows more values to be combined as part of the reduction without discarding the semantics of the partial reduction itself.
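For reference, a source loop of roughly the shape the new partial-reduce-chained.ll tests exercise (a C++ sketch reconstructed from the test IR, not taken verbatim from the patch):

```cpp
#include <cstdint>

// Two multiply sub-reductions chained through a single accumulator. With
// this patch the vectorizer can emit one partial-reduction intrinsic per
// chained term, keeping the partial-reduction semantics throughout.
int32_t chained_partial_reduce_add_sub(const int8_t *a, const int8_t *b,
                                       const int8_t *c, int n) {
  int32_t res = 0;
  for (int i = 0; i < n; ++i) {
    res += (int32_t)a[i] * (int32_t)b[i]; // first partial reduction (add)
    res -= (int32_t)a[i] * (int32_t)c[i]; // chained second term (sub)
  }
  return res;
}
```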