[LoopVectorizer][AArch64] Add support for partial reduce subtraction #123636
I'm having trouble isolating the tests, which I believe is due to the loop vectorizer not currently supporting subtraction recurrences. But I wanted to get the code itself visible even if the tests aren't ready, hence the draft.
Force-pushed from 8b8443a to 26df370.
@llvm/pr-subscribers-backend-aarch64 @llvm/pr-subscribers-vectorizers
Author: Nicholas Guy (NickGuy-Arm)
Changes: Patch is 38.40 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/123636.diff
6 Files Affected:
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index aae2fdaf5bec37..9111efb9bb4e90 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -4673,7 +4673,7 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
InstructionCost Invalid = InstructionCost::getInvalid();
InstructionCost Cost(TTI::TCC_Basic);
- if (Opcode != Instruction::Add)
+ if (Opcode != Instruction::Add && Opcode != Instruction::Sub)
return Invalid;
if (InputTypeA != InputTypeB)
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 57b7358049bcef..c195230e8dcb32 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -8697,8 +8697,9 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
// Build up a set of partial reduction bin ops for efficient use checking.
SmallSet<User *, 4> PartialReductionBinOps;
- for (const auto &[PartialRdx, _] : PartialReductionChains)
+ for (const auto &[PartialRdx, _] : PartialReductionChains) {
PartialReductionBinOps.insert(PartialRdx.BinOp);
+ }
auto ExtendIsOnlyUsedByPartialReductions =
[&PartialReductionBinOps](Instruction *Extend) {
@@ -8761,20 +8762,23 @@ bool VPRecipeBuilder::getScaledReductions(
return false;
using namespace llvm::PatternMatch;
+ BinaryOperator *ExtendedBinOp = BinOp;
+ match(BinOp, m_Neg(m_BinOp(ExtendedBinOp)));
+
Value *A, *B;
- if (!match(BinOp->getOperand(0), m_ZExtOrSExt(m_Value(A))) ||
- !match(BinOp->getOperand(1), m_ZExtOrSExt(m_Value(B))))
+ if (!match(ExtendedBinOp->getOperand(0), m_ZExtOrSExt(m_Value(A))) ||
+ !match(ExtendedBinOp->getOperand(1), m_ZExtOrSExt(m_Value(B))))
return false;
- Instruction *ExtA = cast<Instruction>(BinOp->getOperand(0));
- Instruction *ExtB = cast<Instruction>(BinOp->getOperand(1));
+ Instruction *ExtA = cast<Instruction>(ExtendedBinOp->getOperand(0));
+ Instruction *ExtB = cast<Instruction>(ExtendedBinOp->getOperand(1));
TTI::PartialReductionExtendKind OpAExtend =
TargetTransformInfo::getPartialReductionExtendKind(ExtA);
TTI::PartialReductionExtendKind OpBExtend =
TargetTransformInfo::getPartialReductionExtendKind(ExtB);
- PartialReductionChain Chain(RdxExitInstr, ExtA, ExtB, BinOp);
+ PartialReductionChain Chain(RdxExitInstr, ExtA, ExtB, ExtendedBinOp);
unsigned TargetScaleFactor =
PHI->getType()->getPrimitiveSizeInBits().getKnownScalarFactor(
@@ -8785,7 +8789,7 @@ bool VPRecipeBuilder::getScaledReductions(
InstructionCost Cost = TTI->getPartialReductionCost(
Update->getOpcode(), A->getType(), B->getType(), PHI->getType(),
VF, OpAExtend, OpBExtend,
- std::make_optional(BinOp->getOpcode()));
+ std::make_optional(ExtendedBinOp->getOpcode()));
return Cost.isValid();
},
Range)) {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 2679ed6b26b5d1..eed2476f3770ce 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -25,6 +25,7 @@
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/IR/VectorBuilder.h"
@@ -282,7 +283,20 @@ InstructionCost
VPPartialReductionRecipe::computeCost(ElementCount VF,
VPCostContext &Ctx) const {
std::optional<unsigned> Opcode = std::nullopt;
- VPRecipeBase *BinOpR = getOperand(0)->getDefiningRecipe();
+ VPValue *BinOp = getOperand(0);
+ VPRecipeBase *BinOpR = BinOp->getDefiningRecipe();
+
+ using namespace llvm::PatternMatch;
+ if (auto *UnderInst =
+ dyn_cast_if_present<Instruction>(BinOp->getUnderlyingValue())) {
+ if (match(UnderInst, m_Neg(m_BinOp()))) {
+ BinOpR = BinOpR->getOperand(1)->getDefiningRecipe();
+ }
+ }
+ // BinOp is never used again, any further interaction should be via the
+ // defining recipe `BinOpR`
+ BinOp = nullptr;
+
if (auto *WidenR = dyn_cast<VPWidenRecipe>(BinOpR))
Opcode = std::make_optional(WidenR->getOpcode());
@@ -318,13 +332,20 @@ void VPPartialReductionRecipe::execute(VPTransformState &State) {
State.setDebugLocFrom(getDebugLoc());
auto &Builder = State.Builder;
- assert(getOpcode() == Instruction::Add &&
- "Unhandled partial reduction opcode");
-
Value *BinOpVal = State.get(getOperand(0));
Value *PhiVal = State.get(getOperand(1));
assert(PhiVal && BinOpVal && "Phi and Mul must be set");
+ unsigned Opcode = getOpcode();
+
+ if (Opcode == Instruction::Sub) {
+ bool HasNSW = cast<Instruction>(BinOpVal)->hasNoSignedWrap();
+ BinOpVal = Builder.CreateNeg(BinOpVal, "", HasNSW);
+ Opcode = Instruction::Add;
+ }
+
+ assert(Opcode == Instruction::Add && "Unhandled partial reduction opcode");
+
Type *RetTy = PhiVal->getType();
CallInst *V = Builder.CreateIntrinsic(
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-chained.ll b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-chained.ll
index bedf8b6b3a9b56..2dc515649b2262 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-chained.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-chained.ll
@@ -22,7 +22,7 @@ define i32 @chained_partial_reduce_add_sub(ptr %a, ptr %b, ptr %c, i32 %N) #0 {
; CHECK-NEON-NEXT: br label [[VECTOR_BODY:%.*]]
; CHECK-NEON: vector.body:
; CHECK-NEON-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEON-NEXT: [[VEC_PHI:%.*]] = phi <16 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP13:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEON-NEXT: [[VEC_PHI:%.*]] = phi <4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE3:%.*]], [[VECTOR_BODY]] ]
; CHECK-NEON-NEXT: [[TMP0:%.*]] = add i64 [[INDEX]], 0
; CHECK-NEON-NEXT: [[TMP1:%.*]] = getelementptr inbounds nuw i8, ptr [[A]], i64 [[TMP0]]
; CHECK-NEON-NEXT: [[TMP2:%.*]] = getelementptr inbounds nuw i8, ptr [[B]], i64 [[TMP0]]
@@ -37,14 +37,15 @@ define i32 @chained_partial_reduce_add_sub(ptr %a, ptr %b, ptr %c, i32 %N) #0 {
; CHECK-NEON-NEXT: [[TMP8:%.*]] = sext <16 x i8> [[WIDE_LOAD1]] to <16 x i32>
; CHECK-NEON-NEXT: [[TMP9:%.*]] = sext <16 x i8> [[WIDE_LOAD2]] to <16 x i32>
; CHECK-NEON-NEXT: [[TMP10:%.*]] = mul nsw <16 x i32> [[TMP7]], [[TMP8]]
-; CHECK-NEON-NEXT: [[TMP11:%.*]] = add <16 x i32> [[VEC_PHI]], [[TMP10]]
+; CHECK-NEON-NEXT: [[PARTIAL_REDUCE:%.*]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP10]])
; CHECK-NEON-NEXT: [[TMP12:%.*]] = mul nsw <16 x i32> [[TMP7]], [[TMP9]]
-; CHECK-NEON-NEXT: [[TMP13]] = sub <16 x i32> [[TMP11]], [[TMP12]]
+; CHECK-NEON-NEXT: [[TMP13:%.*]] = sub nsw <16 x i32> zeroinitializer, [[TMP12]]
+; CHECK-NEON-NEXT: [[PARTIAL_REDUCE3]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[PARTIAL_REDUCE]], <16 x i32> [[TMP13]])
; CHECK-NEON-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
; CHECK-NEON-NEXT: [[TMP14:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
; CHECK-NEON-NEXT: br i1 [[TMP14]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
; CHECK-NEON: middle.block:
-; CHECK-NEON-NEXT: [[TMP15:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP13]])
+; CHECK-NEON-NEXT: [[TMP15:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[PARTIAL_REDUCE3]])
; CHECK-NEON-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[WIDE_TRIP_COUNT]], [[N_VEC]]
; CHECK-NEON-NEXT: br i1 [[CMP_N]], label [[FOR_COND_CLEANUP:%.*]], label [[SCALAR_PH]]
;
@@ -114,7 +115,7 @@ define i32 @chained_partial_reduce_add_sub(ptr %a, ptr %b, ptr %c, i32 %N) #0 {
; CHECK-SVE-MAXBW-NEXT: br label [[VECTOR_BODY:%.*]]
; CHECK-SVE-MAXBW: vector.body:
; CHECK-SVE-MAXBW-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-SVE-MAXBW-NEXT: [[VEC_PHI:%.*]] = phi <vscale x 8 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP19:%.*]], [[VECTOR_BODY]] ]
+; CHECK-SVE-MAXBW-NEXT: [[VEC_PHI:%.*]] = phi <vscale x 2 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE3:%.*]], [[VECTOR_BODY]] ]
; CHECK-SVE-MAXBW-NEXT: [[TMP6:%.*]] = add i64 [[INDEX]], 0
; CHECK-SVE-MAXBW-NEXT: [[TMP7:%.*]] = getelementptr inbounds nuw i8, ptr [[A]], i64 [[TMP6]]
; CHECK-SVE-MAXBW-NEXT: [[TMP8:%.*]] = getelementptr inbounds nuw i8, ptr [[B]], i64 [[TMP6]]
@@ -129,14 +130,15 @@ define i32 @chained_partial_reduce_add_sub(ptr %a, ptr %b, ptr %c, i32 %N) #0 {
; CHECK-SVE-MAXBW-NEXT: [[TMP14:%.*]] = sext <vscale x 8 x i8> [[WIDE_LOAD1]] to <vscale x 8 x i32>
; CHECK-SVE-MAXBW-NEXT: [[TMP15:%.*]] = sext <vscale x 8 x i8> [[WIDE_LOAD2]] to <vscale x 8 x i32>
; CHECK-SVE-MAXBW-NEXT: [[TMP16:%.*]] = mul nsw <vscale x 8 x i32> [[TMP13]], [[TMP14]]
-; CHECK-SVE-MAXBW-NEXT: [[TMP17:%.*]] = add <vscale x 8 x i32> [[VEC_PHI]], [[TMP16]]
+; CHECK-SVE-MAXBW-NEXT: [[PARTIAL_REDUCE:%.*]] = call <vscale x 2 x i32> @llvm.experimental.vector.partial.reduce.add.nxv2i32.nxv8i32(<vscale x 2 x i32> [[VEC_PHI]], <vscale x 8 x i32> [[TMP16]])
; CHECK-SVE-MAXBW-NEXT: [[TMP18:%.*]] = mul nsw <vscale x 8 x i32> [[TMP13]], [[TMP15]]
-; CHECK-SVE-MAXBW-NEXT: [[TMP19]] = sub <vscale x 8 x i32> [[TMP17]], [[TMP18]]
+; CHECK-SVE-MAXBW-NEXT: [[TMP19:%.*]] = sub nsw <vscale x 8 x i32> zeroinitializer, [[TMP18]]
+; CHECK-SVE-MAXBW-NEXT: [[PARTIAL_REDUCE3]] = call <vscale x 2 x i32> @llvm.experimental.vector.partial.reduce.add.nxv2i32.nxv8i32(<vscale x 2 x i32> [[PARTIAL_REDUCE]], <vscale x 8 x i32> [[TMP19]])
; CHECK-SVE-MAXBW-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP5]]
; CHECK-SVE-MAXBW-NEXT: [[TMP20:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
; CHECK-SVE-MAXBW-NEXT: br i1 [[TMP20]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
; CHECK-SVE-MAXBW: middle.block:
-; CHECK-SVE-MAXBW-NEXT: [[TMP21:%.*]] = call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> [[TMP19]])
+; CHECK-SVE-MAXBW-NEXT: [[TMP21:%.*]] = call i32 @llvm.vector.reduce.add.nxv2i32(<vscale x 2 x i32> [[PARTIAL_REDUCE3]])
; CHECK-SVE-MAXBW-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[WIDE_TRIP_COUNT]], [[N_VEC]]
; CHECK-SVE-MAXBW-NEXT: br i1 [[CMP_N]], label [[FOR_COND_CLEANUP:%.*]], label [[SCALAR_PH]]
;
@@ -350,7 +352,7 @@ define i32 @chained_partial_reduce_sub_add(ptr %a, ptr %b, ptr %c, i32 %N) #0 {
; CHECK-NEON-NEXT: br label [[VECTOR_BODY:%.*]]
; CHECK-NEON: vector.body:
; CHECK-NEON-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEON-NEXT: [[VEC_PHI:%.*]] = phi <16 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP13:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEON-NEXT: [[VEC_PHI:%.*]] = phi <4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE3:%.*]], [[VECTOR_BODY]] ]
; CHECK-NEON-NEXT: [[TMP0:%.*]] = add i64 [[INDEX]], 0
; CHECK-NEON-NEXT: [[TMP1:%.*]] = getelementptr inbounds nuw i8, ptr [[A]], i64 [[TMP0]]
; CHECK-NEON-NEXT: [[TMP2:%.*]] = getelementptr inbounds nuw i8, ptr [[B]], i64 [[TMP0]]
@@ -365,14 +367,15 @@ define i32 @chained_partial_reduce_sub_add(ptr %a, ptr %b, ptr %c, i32 %N) #0 {
; CHECK-NEON-NEXT: [[TMP8:%.*]] = sext <16 x i8> [[WIDE_LOAD1]] to <16 x i32>
; CHECK-NEON-NEXT: [[TMP9:%.*]] = sext <16 x i8> [[WIDE_LOAD2]] to <16 x i32>
; CHECK-NEON-NEXT: [[TMP10:%.*]] = mul nsw <16 x i32> [[TMP7]], [[TMP8]]
-; CHECK-NEON-NEXT: [[TMP11:%.*]] = sub <16 x i32> [[VEC_PHI]], [[TMP10]]
+; CHECK-NEON-NEXT: [[TMP11:%.*]] = sub nsw <16 x i32> zeroinitializer, [[TMP10]]
+; CHECK-NEON-NEXT: [[PARTIAL_REDUCE:%.*]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP11]])
; CHECK-NEON-NEXT: [[TMP12:%.*]] = mul nsw <16 x i32> [[TMP7]], [[TMP9]]
-; CHECK-NEON-NEXT: [[TMP13]] = add <16 x i32> [[TMP11]], [[TMP12]]
+; CHECK-NEON-NEXT: [[PARTIAL_REDUCE3]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[PARTIAL_REDUCE]], <16 x i32> [[TMP12]])
; CHECK-NEON-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
; CHECK-NEON-NEXT: [[TMP14:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
; CHECK-NEON-NEXT: br i1 [[TMP14]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP8:![0-9]+]]
; CHECK-NEON: middle.block:
-; CHECK-NEON-NEXT: [[TMP15:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP13]])
+; CHECK-NEON-NEXT: [[TMP15:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[PARTIAL_REDUCE3]])
; CHECK-NEON-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[WIDE_TRIP_COUNT]], [[N_VEC]]
; CHECK-NEON-NEXT: br i1 [[CMP_N]], label [[FOR_COND_CLEANUP:%.*]], label [[SCALAR_PH]]
;
@@ -442,7 +445,7 @@ define i32 @chained_partial_reduce_sub_add(ptr %a, ptr %b, ptr %c, i32 %N) #0 {
; CHECK-SVE-MAXBW-NEXT: br label [[VECTOR_BODY:%.*]]
; CHECK-SVE-MAXBW: vector.body:
; CHECK-SVE-MAXBW-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-SVE-MAXBW-NEXT: [[VEC_PHI:%.*]] = phi <vscale x 8 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP19:%.*]], [[VECTOR_BODY]] ]
+; CHECK-SVE-MAXBW-NEXT: [[VEC_PHI:%.*]] = phi <vscale x 2 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE3:%.*]], [[VECTOR_BODY]] ]
; CHECK-SVE-MAXBW-NEXT: [[TMP6:%.*]] = add i64 [[INDEX]], 0
; CHECK-SVE-MAXBW-NEXT: [[TMP7:%.*]] = getelementptr inbounds nuw i8, ptr [[A]], i64 [[TMP6]]
; CHECK-SVE-MAXBW-NEXT: [[TMP8:%.*]] = getelementptr inbounds nuw i8, ptr [[B]], i64 [[TMP6]]
@@ -457,14 +460,15 @@ define i32 @chained_partial_reduce_sub_add(ptr %a, ptr %b, ptr %c, i32 %N) #0 {
; CHECK-SVE-MAXBW-NEXT: [[TMP14:%.*]] = sext <vscale x 8 x i8> [[WIDE_LOAD1]] to <vscale x 8 x i32>
; CHECK-SVE-MAXBW-NEXT: [[TMP15:%.*]] = sext <vscale x 8 x i8> [[WIDE_LOAD2]] to <vscale x 8 x i32>
; CHECK-SVE-MAXBW-NEXT: [[TMP16:%.*]] = mul nsw <vscale x 8 x i32> [[TMP13]], [[TMP14]]
-; CHECK-SVE-MAXBW-NEXT: [[TMP17:%.*]] = sub <vscale x 8 x i32> [[VEC_PHI]], [[TMP16]]
+; CHECK-SVE-MAXBW-NEXT: [[TMP17:%.*]] = sub nsw <vscale x 8 x i32> zeroinitializer, [[TMP16]]
+; CHECK-SVE-MAXBW-NEXT: [[PARTIAL_REDUCE:%.*]] = call <vscale x 2 x i32> @llvm.experimental.vector.partial.reduce.add.nxv2i32.nxv8i32(<vscale x 2 x i32> [[VEC_PHI]], <vscale x 8 x i32> [[TMP17]])
; CHECK-SVE-MAXBW-NEXT: [[TMP18:%.*]] = mul nsw <vscale x 8 x i32> [[TMP13]], [[TMP15]]
-; CHECK-SVE-MAXBW-NEXT: [[TMP19]] = add <vscale x 8 x i32> [[TMP17]], [[TMP18]]
+; CHECK-SVE-MAXBW-NEXT: [[PARTIAL_REDUCE3]] = call <vscale x 2 x i32> @llvm.experimental.vector.partial.reduce.add.nxv2i32.nxv8i32(<vscale x 2 x i32> [[PARTIAL_REDUCE]], <vscale x 8 x i32> [[TMP18]])
; CHECK-SVE-MAXBW-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP5]]
; CHECK-SVE-MAXBW-NEXT: [[TMP20:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
; CHECK-SVE-MAXBW-NEXT: br i1 [[TMP20]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP8:![0-9]+]]
; CHECK-SVE-MAXBW: middle.block:
-; CHECK-SVE-MAXBW-NEXT: [[TMP21:%.*]] = call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> [[TMP19]])
+; CHECK-SVE-MAXBW-NEXT: [[TMP21:%.*]] = call i32 @llvm.vector.reduce.add.nxv2i32(<vscale x 2 x i32> [[PARTIAL_REDUCE3]])
; CHECK-SVE-MAXBW-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[WIDE_TRIP_COUNT]], [[N_VEC]]
; CHECK-SVE-MAXBW-NEXT: br i1 [[CMP_N]], label [[FOR_COND_CLEANUP:%.*]], label [[SCALAR_PH]]
;
@@ -516,7 +520,7 @@ define i32 @chained_partial_reduce_sub_sub(ptr %a, ptr %b, ptr %c, i32 %N) #0 {
; CHECK-NEON-NEXT: br label [[VECTOR_BODY:%.*]]
; CHECK-NEON: vector.body:
; CHECK-NEON-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEON-NEXT: [[VEC_PHI:%.*]] = phi <16 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP13:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEON-NEXT: [[VEC_PHI:%.*]] = phi <4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE3:%.*]], [[VECTOR_BODY]] ]
; CHECK-NEON-NEXT: [[TMP0:%.*]] = add i64 [[INDEX]], 0
; CHECK-NEON-NEXT: [[TMP1:%.*]] = getelementptr inbounds nuw i8, ptr [[A]], i64 [[TMP0]]
; CHECK-NEON-NEXT: [[TMP2:%.*]] = getelementptr inbounds nuw i8, ptr [[B]], i64 [[TMP0]]
@@ -531,14 +535,16 @@ define i32 @chained_partial_reduce_sub_sub(ptr %a, ptr %b, ptr %c, i32 %N) #0 {
; CHECK-NEON-NEXT: [[TMP8:%.*]] = sext <16 x i8> [[WIDE_LOAD1]] to <16 x i32>
; CHECK-NEON-NEXT: [[TMP9:%.*]] = sext <16 x i8> [[WIDE_LOAD2]] to <16 x i32>
; CHECK-NEON-NEXT: [[TMP10:%.*]] = mul nsw <16 x i32> [[TMP7]], [[TMP8]]
-; CHECK-NEON-NEXT: [[TMP11:%.*]] = sub <16 x i32> [[VEC_PHI]], [[TMP10]]
+; CHECK-NEON-NEXT: [[TMP11:%.*]] = sub nsw <16 x i32> zeroinitializer, [[TMP10]]
+; CHECK-NEON-NEXT: [[PARTIAL_REDUCE:%.*]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP11]])
; CHECK-NEON-NEXT: [[TMP12:%.*]] = mul nsw <16 x i32> [[TMP7]], [[TMP9]]
-; CHECK-NEON-NEXT: [[TMP13]] = sub <16 x i32> [[TMP11]], [[TMP12]]
+; CHECK-NEON-NEXT: [[TMP13:%.*]] = sub nsw <16 x i32> zeroinitializer, [[TMP12]]
+; CHECK-NEON-NEXT: [[PARTIAL_REDUCE3]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[PARTIAL_REDUCE]], <16 x i32> [[TMP13]])
; CHECK-NEON-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
; CHECK-NEON-NEXT: [[TMP14:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
; CHECK-NEON-NEXT: br i1 [[TMP14]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP10:![0-9]+]]
; CHECK-NEON: middle.block:
-; CHECK-NEON-NEXT: [[TMP15:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP13]])
+; CHECK-NEON-NEXT: [[TMP15:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[PARTIAL_REDUCE3]])
; CHECK-NEON-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[WIDE_TRIP_COUNT]], [[N_VEC]]
; CHECK-NEON-NEXT: br i1 [[CMP_N]], label [[FOR_COND_CLEANUP:%.*]], label [[SCALAR_PH]]
;
@@ -608,7 +614,7 @@ define i32 @chained_partial_reduce_sub_sub(ptr %a, ptr %b, ptr %c, i32 %N) #0 {
; CHECK-SVE-MAXBW-NEXT: br label [[VECTOR_BODY:%.*]]
; CHECK-SVE-MAXBW: vector.body:
; CHECK-SVE-MAXBW-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-SVE-MAXBW-NEXT: [[VEC_PHI:%.*]] = phi <vscale x 8 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP19:%.*]], [[VECTOR_BODY]] ]
+; CHECK-SVE-MAXBW-NEXT: [[VEC_PHI:%.*]] = phi <vscale x 2 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE3:%.*]], [[VECTOR_BODY]] ]
; CHECK-SVE-MAXBW-NEXT: [[TMP6:%.*]] = add i64 [[INDEX]], 0
; CHECK-SVE-MAXBW-NEXT: [[TMP7:%.*]] = getelementptr inbounds nuw i8, ptr [[A]], i64 [[TMP6]]
; CHECK-SVE-MAXBW-NEXT: [[TMP8:%.*]] = getelementptr inbounds nuw i8, ptr [[B]], i64 [[TMP6]]
@@ -623,14 +629,16 @@ define i32 @chained_partial_reduce_sub_sub(ptr %a, ptr %b, ptr %c, i32 %N) #0 {
; CHECK-SVE-MAXBW-NEXT: [[TMP14:%.*]] = sext <vscale x 8 x i8> [[WIDE_LOAD1]] to <vscale x 8 x i32>
; CHECK-SVE-MAXBW-NEXT: [[TMP15:%.*]] = sext <vscale x 8 x i8> [[WIDE_LOAD2]] to <vscale x 8 x i32>
; CHECK-SVE-MAXBW-NEXT: [[TMP16:%.*]] = mul nsw <vscale x 8 x i32> [[TMP13]], [[TMP14]]
-; CHECK-SVE-MAXBW-NEXT: [[TMP17:%.*]] = sub <vscale x 8 x i32> [[VEC_PHI]], [[TMP16]]
+; CHECK-SVE-MAXBW-NEXT: [[TMP17:%.*]] = sub nsw <vscale x 8 x i32> zeroinitializer, [[TMP16]]
+; CHECK-SVE-MAXBW-NEXT: [[PARTIAL_REDUCE:%.*]] = call <vscale x 2 x i32> @llvm.experimental.vector.partial.reduce.add.nxv2i32.nxv8i32(<vscale x 2 x i32> [[VEC_PHI]], <vscale x 8 x i32> [[TMP17]])
; CHECK-SVE-MAXBW-NEXT: [[TMP18:%.*]] = mul nsw <vscale x 8 x i32> [[TMP13]], [[TMP15]]
-; CHECK-SVE-MAXBW-NE...
[truncated]
for (const auto &[PartialRdx, _] : PartialReductionChains) {
  PartialReductionBinOps.insert(PartialRdx.BinOp);
}
I don't think we need this change.
Correct, I needed the braces for some extra debug output. Removed.
BinaryOperator *ExtendedBinOp = BinOp;
match(BinOp, m_Neg(m_BinOp(ExtendedBinOp)));
Could this just set `BinOp` inside the match instead of creating a new variable?
Another remnant of a previous iteration. We did use some info from `BinOp` (I think initially it was for costing), but now we're costing from `ExtendedBinOp`. Removed the new variable.
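A minimal sketch of the resulting code, assuming `BinOp` is the `BinaryOperator *` already validated earlier in `getScaledReductions`: `m_BinOp` can bind the matched inner operator back into the same variable, and the bind only happens once the pattern has matched that far.

using namespace llvm::PatternMatch;
// If BinOp is `sub 0, %x` and %x is itself a BinaryOperator, rebind
// BinOp to %x. Otherwise BinOp is left unchanged; the return value of
// match is deliberately ignored, since the rebinding is the side effect
// we want.
match(BinOp, m_Neg(m_BinOp(BinOp)));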
Looks good to me, with a couple of suggestions.
using namespace llvm::PatternMatch;
if (auto *UnderInst =
        dyn_cast_if_present<Instruction>(BinOp->getUnderlyingValue())) {
  if (match(UnderInst, m_Neg(m_BinOp()))) {
    BinOpR = BinOpR->getOperand(1)->getDefiningRecipe();
  }
}
A comment explaining why this has to get the 2nd operand of the BinOp would be good.
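A sketch of the kind of comment being requested (wording assumed; the code line is as in the patch):

// The underlying instruction is `sub 0, %x`, so operand 0 of the
// matched recipe is the zero constant and operand 1 is the widened
// bin-op being negated; step past the negation so the underlying
// operation is the one that gets costed.
BinOpR = BinOpR->getOperand(1)->getDefiningRecipe();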
// BinOp is never used again, any further interaction should be via the
// defining recipe `BinOpR`
BinOp = nullptr;
It feels a bit strange to set something to null like this; perhaps it should just be left alone. Or we could just not declare the variable in the first place and use `getOperand(0)` instead.
if (auto *UnderInst =
        dyn_cast_if_present<Instruction>(BinOp->getUnderlyingValue())) {
  if (match(UnderInst, m_Neg(m_BinOp()))) {
    BinOpR = BinOpR->getOperand(1)->getDefiningRecipe();
  }
Can't we match this on the VPValues/VPRecipes?
We can indeed. Done.
@@ -2214,7 +2214,7 @@ define i32 @not_dotp_extend_user(ptr %a, ptr %b) #0 {
 ; CHECK-MAXBW-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK-MAXBW: vector.body:
 ; CHECK-MAXBW-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-MAXBW-NEXT:    [[VEC_PHI1:%.*]] = phi <vscale x 8 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP24:%.*]], [[VECTOR_BODY]] ]
+; CHECK-MAXBW-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 8 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP23:%.*]], [[VECTOR_BODY]] ]
unrelated changes?
Seemingly so; I'm not sure why the update script changed only these. Reverted, as there are no other changes in this file.
if (Opcode == Instruction::Sub) {
  bool HasNSW = cast<Instruction>(BinOpVal)->hasNoSignedWrap();
  BinOpVal = Builder.CreateNeg(BinOpVal, "", HasNSW);
  Opcode = Instruction::Add;
}
The operands are always VPWidenRecipes, correct?
Could we instead adjust the input recipes to first negate the operand, and use Add instead of Sub for the partial reduction recipe?
They are always Widen recipes as far as I can tell; however, the incoming IR doesn't always match the pattern `add(<a>, neg(<b>))`. In the case of complex dot products, the second sub is explicitly represented as a sub, which we then transform here to follow the aforementioned pattern.
Unless I'm missing something, there doesn't seem to be a method of creating multiple VPRecipes from a single source instruction, which would be required for this suggestion.
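For context, a hypothetical source loop of the shape being described: the real part of a complex integer dot product subtracts its second widened multiply, so the scalar IR carries an explicit sub (function and variable names are illustrative).

#include <cstdint>

// Real part of a complex dot product: sum += ar*br - ai*bi. The second
// product is subtracted, giving the subtracting partial reduction this
// patch handles.
int32_t cdot_real(const int8_t *ar, const int8_t *ai, const int8_t *br,
                  const int8_t *bi, int n) {
  int32_t sum = 0;
  for (int i = 0; i < n; ++i)
    sum += (int32_t)ar[i] * br[i] - (int32_t)ai[i] * bi[i];
  return sum;
}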
I think in general it should be possible to create multiple recipes for a single source instruction (we do it in multiple places already IIRC), as long as the additional instructions do not need to be mapped to IR instructions for lookup later.
A recent example is https://github.com/llvm/llvm-project/pull/124268/files#diff-da321d454a7246f8ae276bf1db2782bf26b5210b8133cb59e4d7fd45d0905decR8900
Done, though it caused the cost model to evaluate the VPlan as being more expensive, so it doesn't emit scalable vectors by default in this case. (I believe it's a similar case to why `-vectorizer-maximize-bandwidth` is now a thing, but even with that, fixed-width plans are chosen over scalable plans.) A workaround is to use the opt arguments `-force-vector-width=8/16 -scalable-vectorization=preferred` in tandem (or in C++ by using a loop pragma `vectorize_width(8/16, scalable)`) to force a VF that is supported by partial reductions.
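A usage sketch of the pragma form just mentioned (loop body and names hypothetical):

#include <cstdint>

int32_t sub_reduce(const int8_t *a, const int8_t *b, int n) {
  int32_t sum = 0;
// Request a scalable VF of 16, a width the partial-reduction lowering
// supports; roughly the pragma equivalent of the opt flags above.
#pragma clang loop vectorize_width(16, scalable)
  for (int i = 0; i < n; ++i)
    sum -= (int32_t)a[i] * (int32_t)b[i];
  return sum;
}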
auto *const Zero = ConstantInt::get(Reduction->getType(), 0);
SmallVector<VPValue *, 2> Ops;
Ops.push_back(Plan.getOrAddLiveIn(Zero));
Ops.push_back(cast<VPWidenRecipe>(BinOp->getDefiningRecipe()));
Why does this need to cast? Can it just add `BinOp` to the list?
I was initially using the recipe, and after a few iterations got rid of the direct recipe usage. This stuck around as I didn't register at the time that `BinOp` in this context is the same as the recipe. Will be fixed in the next commit (assuming no test failures).
@@ -632,7 +632,7 @@ define i32 @chained_partial_reduce_sub_sub(ptr %a, ptr %b, ptr %c, i32 %N) #0 {
 ; CHECK-SVE-MAXBW-NEXT:    [[TMP17:%.*]] = sub nsw <vscale x 8 x i32> zeroinitializer, [[TMP16]]
 ; CHECK-SVE-MAXBW-NEXT:    [[PARTIAL_REDUCE:%.*]] = call <vscale x 2 x i32> @llvm.experimental.vector.partial.reduce.add.nxv2i32.nxv8i32(<vscale x 2 x i32> [[VEC_PHI]], <vscale x 8 x i32> [[TMP17]])
 ; CHECK-SVE-MAXBW-NEXT:    [[TMP18:%.*]] = mul nsw <vscale x 8 x i32> [[TMP13]], [[TMP15]]
-; CHECK-SVE-MAXBW-NEXT:    [[TMP19:%.*]] = sub nsw <vscale x 8 x i32> zeroinitializer, [[TMP18]]
+; CHECK-SVE-MAXBW-NEXT:    [[TMP19:%.*]] = sub <vscale x 8 x i32> zeroinitializer, [[TMP18]]
Should the flags be discarded?
In this case, they're not being discarded so much as they simply don't exist. This is the sub that is being created by matching an existing sub instruction and replacing it with `add(%a, sub(0, %b))`. I'm also not sure that the flags will do anything in this case, as the only time it would overflow is if TMP18 had already overflowed, but in that case it would be a poison value. I'm happy to add them if they're deemed necessary/helpful though.
Correction: we do already propagate the flags; however, in this case the flags are missing because they aren't on the respective `sub` instruction in the source IR.
LGTM!
Instead of implementing a new intrinsic for subtracting partial reductions, generate a negation instruction for the second operand of the partial reduction.
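In IRBuilder terms, a minimal sketch of the earlier in-execute form of this idea (the intrinsic-call shape is inferred from the truncated diff above; the final patch instead builds the negation as a separate recipe, per the `getOrAddLiveIn(Zero)` snippet quoted earlier):

// A Sub partial reduction becomes a negation feeding the add intrinsic.
// CreateNeg emits `sub 0, %x` (zeroinitializer for vectors), keeping nsw
// only when the original sub had it.
if (Opcode == Instruction::Sub) {
  bool HasNSW = cast<Instruction>(BinOpVal)->hasNoSignedWrap();
  BinOpVal = Builder.CreateNeg(BinOpVal, "", HasNSW);
  Opcode = Instruction::Add;
}
CallInst *V = Builder.CreateIntrinsic(
    RetTy, Intrinsic::experimental_vector_partial_reduce_add,
    {PhiVal, BinOpVal}, nullptr, "partial.reduce");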
Force-pushed from 7e3a8e6 to 3757dd5.
✅ With the latest revision this PR passed the C/C++ code formatter.
VPBasicBlock *ParentBlock = Builder.getInsertBlock();
if (!ParentBlock)
  return nullptr;
Can the insert block ever be null? I'd expect it to always be set at this point?
I haven't seen any case where it would be, so this was just a bit of defensiveness for something I wasn't 100% sure on. I can switch it to an assert, or remove it altogether if you think the check is irrelevant.
Edit: Looks like `VPBuilder::clearInsertPoint` can result in `getInsertPoint` returning a nullptr, so I've changed it to an assert.
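The assert form described, as a sketch:

VPBasicBlock *ParentBlock = Builder.getInsertBlock();
// The builder's insert point can be cleared via VPBuilder::clearInsertPoint,
// so guard against that rather than silently bailing out.
assert(ParentBlock && "Expected an insert block to be set");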
LGTM, thanks!
@@ -8800,6 +8800,10 @@ bool VPRecipeBuilder::getScaledReductions(
     return false;

+  using namespace llvm::PatternMatch;
+  // Use the side-effect of match to replace BinOp only if the pattern is
+  // matched, we don't care at this point whether it actually matched.
+  match(BinOp, m_Neg(m_BinOp(BinOp)));
I don't understand how this works. We've already bailed out above if `BinOp` is not a BinaryOperator, which surely means that it cannot simultaneously be a unary operator, which is what `m_Neg` represents?
Ah, I realise now that `m_Neg` is actually matching a `sub 0, %x` operation. Please ignore the noise!
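For the record, `m_Neg` in llvm/IR/PatternMatch.h is shorthand for a sub with a zero left-hand side, so it matches a BinaryOperator rather than a unary operator. A small sketch:

using namespace llvm::PatternMatch;
// m_Neg(X) is equivalent to m_Sub(m_ZeroInt(), X): it matches the
// two-operand instruction `sub 0, %x` and runs the inner pattern X
// against %x.
Value *X;
if (match(I, m_Neg(m_Value(X)))) {
  // I is `sub 0, X` (including the vector splat-of-zero form).
}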